#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  if (isIm2Col) {
    // Im2col mode requires a tensor of at least rank 3.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // When offsets are given, there must be exactly (tensorDims - 2) of them.
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

// ...
  if (getCoordinates().size() > 5)
    return emitError("Maximum 5 coordinates and dimension is supported.");
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16 bytes copy.");
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

// ...
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  using RndMode = NVVM::FPRoundingMode;
  switch (getRnd()) {
  case RndMode::RNA:
    if (getRelu())
      return emitError("Relu not supported with rna rounding mode.");
    break;
  case RndMode::RN:
  case RndMode::RZ:
    break;
  default:
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");
  }
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type = LLVM::getFixedVectorType(
      Float16Type::get(operandElType.getContext()), 2);
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}
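// Illustrative sketch (not part of the original source): results of the
// inference above for common element types.
//
//   f64                     -> MMATypes::f64
//   f16 or vector<2xf16>    -> MMATypes::f16
//   f32, accumulator        -> MMATypes::f32
//   f32, multiplicand       -> MMATypes::tf32
//   iN, accumulator         -> MMATypes::s32
//   struct {T, ...}         -> inferred from its first member T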
static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }
  // ...
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
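// Illustrative sketch (not part of the original source; value names are
// hypothetical): building an m16n8k16 f16 MMA where the multiplicand PTX
// types are left to be inferred from the operand element types.
//
//   MmaOp::build(builder, state, /*resultType=*/f16x2x2StructTy,
//                /*operandA=*/{a0, a1, a2, a3}, /*operandB=*/{b0, b1},
//                /*operandC=*/{c0, c1}, /*shape=*/{16, 8, 16},
//                /*b1Op=*/std::nullopt, /*intOverflow=*/std::nullopt,
//                /*multiplicandPtxTypes=*/std::nullopt,  // inferred
//                /*multiplicandLayouts=*/std::nullopt);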
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;
  NamedAttrList namedAttributes;

  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    // Parse the operand keyword followed by its bracketed register list.
    // ...
    return success();
  };

  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  // ... (attribute dictionary, operand types, and the result type are parsed)
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    // ...
    frag.elemtype = inferOperandMMAType(frag.regTypes[0], iter.index() < 2);
  }

  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);
  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  // ...
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes.getAttrs());
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {static_cast<int32_t>(frags[0].regs.size()),
                           static_cast<int32_t>(frags[1].regs.size()),
                           static_cast<int32_t>(frags[2].regs.size())}));
  // ... (operand resolution elided)
}
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f64Ty = Float64Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};
  // These variables define the set of allowed data types for matrices A, B,
  // C, and the result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    // ... (integer and b1 cases elided in this listing)
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }
  // In the M=8 case, there is only one possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, {f64Ty, f64Ty}));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }
  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }
  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }
  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }
  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Integer MMA variants must specify the accumulator overflow behavior.
  if (isIntegerPtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = LLVM::getFixedVectorType(builder.getF16Type(), 2);
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    if (parallelSize == 16)
      numberElements = 2;
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}
LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");
  return success();
}
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose; only f16 and bf16 inputs support transposed layouts.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check the number of result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // `satfinite` is only available with an s32 accumulator.
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }
  ss << "},";
  // Map the read/write registers after the output registers.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  // ... (result, input and descriptor operands are appended first)
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}
LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}
LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}
LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be in between 24 to 256");
  return success();
}
LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 to 15");
  return success();
}
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
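// Illustrative sketch (not part of the original source): GET_CP_ASYNC_ID(ca,
// 8, true) expands to CP_ASYNC_ID_IMPL(ca, 8, _s), i.e. to
// llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8_s; passing false selects
// the variant without the trailing _s (no explicit copy-size operand).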
  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = cpAsyncOp.getCpSize() ? true : false;
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
             : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  // Populate the intrinsic arguments.
  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()
llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  // ... (the remaining reduction kinds follow the same pattern)
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )
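// Illustrative sketch (not part of the original source): GET_CVT_F2TF32_ID(rn,
// _relu, _satfinite) selects among llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite,
// nvvm_f2tf32_rn_satfinite, nvvm_f2tf32_rn_relu and nvvm_f2tf32_rn, depending
// on the hasSatFinite and hasRelu flags captured from the enclosing function.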
llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                                     NVVM::SaturationMode sat,
                                                     bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, , _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace;
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
llvm::Intrinsic::ID
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace;
  bool hasMulticast = curOp.getMulticastMask() ? true : false;
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Populate the intrinsic arguments.
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));
  return id;
}
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}
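// Illustrative sketch (not part of the original source): a special-register
// op carrying a range attribute, e.g.
//
//   %tid = nvvm.read.ptx.sreg.tid.x range <i32, 0, 1024> : i32
//
// would have its result range constrained to [0, 1024) by the hook above.
// (The exact attribute syntax is assumed here for illustration.)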
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
  // dimensions.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
  }
  // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
  // attributes.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
  }
  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"