#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
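// NVVMOpsDialect.cpp.inc and NVVMOpsEnums.cpp.inc above are generated by
// TableGen from the NVVM dialect definitions.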
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  if (numIm2ColOffsets) {
    // In im2col mode the tensor must be at least 3-dimensional.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // The number of im2col offsets must be exactly tensorDims - 2.
    if (tensorDims != (numIm2ColOffsets + 2))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
                                         getIm2colOffsets().size(), getLoc());
}
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("Maximum 5 coordinates and dimension is supported.");
  return success();
}
LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only support for 16 bytes copy.");
  return success();
}
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
                                         getIm2colOffsets().size(), getLoc());
}
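// Infers the PTX data type (MMATypes) for an MMA operand from its MLIR
// element type. Accumulators are distinguished because f32 maps to f32 for
// accumulators but to tf32 for multiplicands.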
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type = LLVM::getFixedVectorType(
      Float16Type::get(operandElType.getContext()), 2);
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}
static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  // Collect the registers of each operand segment and mark the PTX type
  // attributes that can be elided because they are inferable.
  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }
  // ... (attribute dictionary and the operand/result type list follow)
}
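// Builder for nvvm.mma.sync: assembles the operand segments, shape, and
// optional attributes, inferring the PTX multiplicand types when they are
// not given explicitly.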
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  // Take the multiplicand PTX types as given, or infer them from the first
  // register of each segment.
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
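// Parser counterpart of the printer above: reads the A/B/C segments, the
// attribute dictionary, and the type signature, then re-infers any elided
// PTX type attributes.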
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse one named operand segment, e.g. `A[$a0, $a1]`.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  // Parse the optional attribute dictionary and the operand type list.
  SmallVector<Type> operandTypes;
  if (parser.parseOptionalAttrDict(namedAttributes).failed() ||
      parser.parseColon().failed() || parser.parseLParen().failed() ||
      parser.parseTypeList(operandTypes).failed() ||
      parser.parseRParen().failed())
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0], iter.index() < 2);
  }

  Type resultType;
  if (parser.parseArrow().failed() || parser.parseType(resultType).failed())
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}
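// Verifier for nvvm.mma.sync: enumerates the allowed (shape, type)
// combinations and checks the operand segments and result type against them.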
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};
  // These variables define the set of allowed data types for matrices A, B,
  // C, and the result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, the fragment sizes are determined by the number of
  // 8-by-kFactor tiles, where kFactor depends on the multiplicand data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    // ... (bf16, s4/u4, b1, and s8/u8 cases select kFactor accordingly)
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    // Integer variants accumulate into four i32 registers.
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }
  // In the M = 8 case, there is only one possible variant per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }
  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }
  // Verify the operand types for each segment of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }
  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }
  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accumulator overflow behavior.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
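// The shfl.sync op may also return a validity predicate: with the
// return_value_and_is_valid unit attribute, the result must be a
// {value, i1} struct.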
LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag,
                                                   int nRow, int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  // Select the row/column extents that apply to the requested fragment.
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  // The result must be a struct of the inferred fragment element type.
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  // Build the full expected argument list from the three fragment types.
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}
LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose; it is only supported for f16 and bf16.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check the number of result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }

  // satfinite is only supported with an s32 accumulator.
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}
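// Builds the inline-PTX string for wgmma.mma_async. The register numbers
// ($0, $1, ...) must line up with the operand order produced by
// getAsmValues() below.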
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";

  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }
  ss << "},";

  // Map the read/write registers that follow the output registers.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  // ... (results, inouts, and the descriptor operands are appended first,
  // in the register order assumed by getPtx())
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}
LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}
LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}
LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}
LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be in between 24 to 256");
  return success();
}
LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 to 15");
  return success();
}
llvm::Intrinsic::ID
CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims, bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
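// Dialect-level verification of NVVM-specific operation attributes (kernel
// entry, maxntid/reqntid launch bounds, minctasm, maxnreg).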
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // The kernel function attribute may only be attached to LLVM functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid is set, it must be an integer array with at most
  // three entries.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
  }
  // If minctasm / maxnreg is set, it must be an integer attribute.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
  }
  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"