#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
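// The helper below is shared by the cp.async.bulk.tensor (TMA) op verifiers.
// In short, it checks that the coordinate rank is between 1 and 5, that
// im2col mode is used only with tensors of at least three dimensions, and
// that any im2col offsets number exactly tensorDims - 2 (e.g. a 5-D im2col
// copy carries 3 offsets).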
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");
  // Im2col mode requires an at-least-3-D tensor and tensorDims - 2 offsets.
  if (isIm2Col) {
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("Maximum 5 coordinates and dimension is supported.");
  return success();
}

LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only support for 16 bytes copy.");
  return success();
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         /*numIm2ColOffsets=*/0, getLoc());
}
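// Illustrative mapping implemented by inferOperandMMAType below (not
// exhaustive): f64 -> f64, f16 or vector<2xf16> -> f16, f32 -> f32 when used
// as the accumulator but tf32 when used as a multiplicand, integer operands
// map to s32 accumulators, and LLVM struct operands are classified by their
// first element type.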
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type = LLVM::getFixedVectorType(
      Float16Type::get(operandElType.getContext()), 2);
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}

static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the operand segment types and the result type.
  p << " : (";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(this->getOperation()->getResultTypes());
}
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  // PTX multiplicand types: use the explicitly supplied types, otherwise
  // infer them from the first A/B operand.
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }
  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  }
  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  std::array<OperandFragment, 4> frags;
  NamedAttrList namedAttributes;

  // Parse one operand segment of the form `A[$a0, $a1, ...]`.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName) ||
        parser.parseOperandList(frag.regs,
                                OpAsmParser::Delimiter::OptionalSquare))
      return failure();
    return success();
  };

  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the `: (typeA, typeB, typeC) -> resultType` trailer.
  SmallVector<Type> operandTypes;
  Type resultType;
  if (parser.parseColon() || parser.parseLParen() ||
      parser.parseTypeList(operandTypes) || parser.parseRParen() ||
      parser.parseArrow() || parser.parseType(resultType))
    return failure();

  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");

  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0], iter.index() < 2);
  }
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes.getAttrs());
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f64Ty = Float64Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C
  // and the result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, the number of registers per multiplicand is derived from the
  // number of 8xk tiles, where k depends on the element type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M = 8 case, there is only one allowed variant per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(
          LLVM::LLVMStructType::getLiteral(context, {f64Ty, f64Ty}));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }
  // Verify the operand types for each of the A, B and C segments.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Binary (b1) variants must carry the b1 op attribute; int4/int8 variants
  // must specify the accumulator overflow behavior.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op())
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = LLVM::getFixedVectorType(builder.getF16Type(), 2);
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;
    if (parallelSize == 16)
      numberElements = 2;
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}

static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}
LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}

LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K: it is fixed per input element type.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Transposed layouts are only supported for f16 and bf16 inputs.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check the number of result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }

  // satfinite is only meaningful with an s32 accumulator.
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";

  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }
  ss << "},";

  // The accumulator registers are tied (read-write), so continue operand
  // numbering at 2 * regCnt.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Transpose parameters are only emitted for f16/bf16 inputs.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  // Bind the output struct, the inouts and the two matrix descriptors first,
  // then the scale/layout immediates referenced by the PTX string above.
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}
LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}
LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be in between 24 to 256");
  return success();
}

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 to 15");
  return success();
}
llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  // ... analogous cases dispatch MIN, MAX, INC, DEC, AND, OR and XOR ...
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}

LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // The kernel marker may only be attached to LLVM functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // maxntid, reqntid and cluster_dim must be integer arrays with at most
  // three entries.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
  }
  // minctasm, maxnreg and cluster_max_blocks must be integer constants.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
  }
  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"