#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  if (isIm2Col && tensorDims < 3)
    return emitError(
        loc, "to use im2col mode, the tensor has to be at least 3-dimensional");

  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
    return emitError(
        loc, "im2col offsets must be 2 less than number of coordinates");

  return success();
}

  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
  TMAStoreMode mode = getMode();

  // Inline-PTX lowering only supports TILE mode and no L2 cache-hint.
  if (mode != TMAStoreMode::TILE)
    return emitError("Inline-ptx lowering supported only for Tile mode.");
  if (getL2CacheHint())
    return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");

  switch (mode) {
  case TMAStoreMode::TILE:
    // ...
  case TMAStoreMode::IM2COL:
    // ...
  case TMAStoreMode::TILE_SCATTER4:
    if (getCoordinates().size() != 5)
      return emitError("Scatter4 mode expects 5 coordinates");
  }
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only support for 16 bytes copy.");
static LogicalResult verifyTMALoadParams(size_t tensorDims,
                                         size_t numIm2colOff, TMALoadMode mode,
                                         Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedIm2colOff) -> LogicalResult {
    if (isIm2col && (tensorDims < 3))
      return emitError(loc)
             << "to use " << stringifyEnum(mode)
             << " mode, the tensor has to be at least 3-dimensional";

    if (numIm2colOff != expectedIm2colOff)
      return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
                            << " (provided " << numIm2colOff << ")";

    return success();
  };

  switch (mode) {
  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    return (tensorDims == 5)
               ? checkTMALoadParams(mode, false, 0)
               : emitError(loc, "Gather4 mode expects 5 coordinates");
  }
  return success();
}
  return verifyTMALoadParams(getCoordinates().size(),
                             getIm2colOffsets().size(), getMode(), getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  TMAStoreMode mode = getMode();
  switch (mode) {
  case TMAStoreMode::TILE:
    // ...
  case TMAStoreMode::IM2COL:
    // ...
  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
  }
}
  using RndMode = NVVM::FPRoundingMode;
  if (getRnd() == RndMode::RNA && getRelu())
    return emitError("Relu not supported with rna rounding mode.");
  if (getRnd() != RndMode::RN && getRnd() != RndMode::RZ &&
      getRnd() != RndMode::RNA)
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
LogicalResult ConvertF32x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;

  bool hasRelu = getRelu();

  switch (getType()) {
  case ConvertFP8Type::E4M3:
  case ConvertFP8Type::E5M2:
    if (!isRoundingModeRN)
      return emitOpError("Only RN rounding mode is supported for conversions "
                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
    if (!isSatFinite)
      return emitOpError("Only SATFINITE saturation mode is supported for "
                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
    break;
  case ConvertFP8Type::UE8M0:
    if (!(isRoundingModeRZ || isRoundingModeRP))
      return emitOpError("Only RZ or RP rounding modes are supported for "
                         "conversions from f32x2 to .ue8m0x2 type");
    if (hasRelu)
      return emitOpError("relu not supported for conversions to .ue8m0x2 type");
    break;
  }
  return success();
}
  if (getType() == ConvertFP8Type::UE8M0)
    return emitOpError("Only .e4m3 or .e5m2 types are supported for "
                       "conversions from f16x2 to f8x2.");
  using RndMode = NVVM::FPRoundingMode;

  if (getType() != ConvertFP8Type::UE8M0)
    return emitOpError(
        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

  auto rnd = getRnd();
  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");
  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ")
           << getInitVal();
  auto eventId = getEventId();
  auto maskedEventId = getMaskedEventId();
  if (!maskedEventId && !eventId) {
    return emitOpError() << "either `id` or `mask` must be set";
  }

  if (maskedEventId && eventId) {
    return emitOpError() << "`id` and `mask` cannot be set at the same time";
  }

  if (eventId < 0 || eventId > 15) {
    return emitOpError() << "`id` must be between 0 and 15";
  }

  return llvm::success();
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}

static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };
  SmallVector<Type, 4> regTypes;

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    // ...
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  if (multiplicandPtxTypes) {
    // ...
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false)) {
      // ...
    }
    if (auto res = inferOperandMMAType(operandB[0].getType(), false)) {
      // ...
    }
  }

  if (multiplicandLayouts) {
    // ...
  }

  if (intOverflow.has_value()) {
    // ...
  }
  if (b1Op.has_value()) {
    // ...
  }

  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<Value, 4> regs;
    SmallVector<Type> regTypes;
  };

  std::array<OperandFragment, 4> frags;

  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    // ...
  };

  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");

  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    // ...
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() >= 2);
  }

  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value()) {
      // ...
    }
  }

  if (!namedAttributes.empty()) {
    // ...
  }
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These collections hold the allowed shapes and operand/result types for
  // the requested MMA variant.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, only the number of 8xk tiles needs to be computed, where k
  // is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    // ... (integer and sub-byte cases elided)
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M = 8 case, there is only one possible variant per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(
          LLVM::LLVMStructType::getLiteral(context, {f64Ty, f64Ty}));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that an existing shape/dtype combination was matched.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }
  // Check the operand types.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure that int4/int8 MMA variants specify the accumulator overflow
  // behavior attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    if (parallelSize == 16)
      numberElements = 2;
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
  if (m == 8 && n == 8) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
                         "matrix");
    }
    if (getEltType() != LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be b16 for 8x8 matrix");
    }
  } else if (m == 8 && n == 16) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
                         "matrix");
    }
    if (getLayout() != MMALayout::row) {
      return emitOpError("expected layout to be row for 8x16 matrix");
    }
    if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 8x16 matrix");
    }
  } else if (m == 16 && n == 16) {
    if (num != 1 && num != 2) {
      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
                         "matrix");
    }
    if (getLayout() != MMALayout::col) {
      return emitOpError("expected layout to be col for 16x16 matrix");
    }
    if (getEltType() != LdStMatrixEltType::B8 &&
        getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 16x16 matrix");
    }
  } else {
    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
  }

  Type i32 = IntegerType::get(getContext(), 32);
  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
  if (numElements == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (numElements == 2 || numElements == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(numElements, i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << numElements << " elements of type i32";
  }
  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  if (m == 8 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be B16 for 8x8 matrix");
    }
  } else if (m == 16 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B8) {
      return emitOpError("expected element type to be B8 for 16x8 matrix");
    }
    if (getLayout() != NVVM::MMALayout::col) {
      return emitOpError("expected layout to be col for 16x8 matrix");
    }
  } else {
    return emitOpError("expected shape to be 8x8 or 16x8");
  }
static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                            NVVM::WGMMATypes typeA,
                                            NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Transposed layouts are only supported for f16 and bf16.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check the number of result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // `satfinite` is only available with an s32 accumulator.
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";

  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }
  ss << "},";

  // Map the read/write operand registers after the output registers.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Transpose parameters are only added when needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  // ...
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be in between 24 to 256");
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 to 15");
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two element struct with "
                         "first element as i32 and second element as i1");
    }
  } else {
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");
    }
  }
  return success();
  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    }
  } else {
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
    }
  }
  return success();
LogicalResult PrefetchOp::verify() {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();

  if (getTensormap() && cacheLevel)
    return emitOpError("cannot specify both tensormap and cache level");

  if (getTensormap()) {
    if (addressSpace != MemSpace::Generic &&
        addressSpace != MemSpace::Constant) {
      return emitOpError(
          "prefetch tensormap requires a generic or constant pointer");
    }

    if (evictPriority) {
      return emitOpError(
          "prefetch tensormap does not support eviction priority");
    }

    if (getInParamSpace() && addressSpace != MemSpace::Generic) {
      return emitOpError(
          "in_param_space can only be specified for a generic pointer");
    }

  } else if (cacheLevel) {
    if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
        addressSpace != MemSpace::Local) {
      return emitOpError("prefetch to cache level requires a generic, global, "
                         "or local pointer");
    }

    if (getUniform()) {
      if (*cacheLevel != CacheLevel::L1) {
        return emitOpError(
            "unsupported cache level, the only supported uniform "
            "cache level is L1");
      }

      if (addressSpace != MemSpace::Generic) {
        return emitOpError(
            "prefetch to uniform cache requires a generic pointer");
      }
    }

    if (evictPriority) {
      if (*cacheLevel != CacheLevel::L2)
        return emitOpError(
            "cache eviction priority supported only for cache level L2");

      if (addressSpace != MemSpace::Global)
        return emitOpError("cache eviction priority requires a global pointer");

      if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
          *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
        return emitOpError(
            "unsupported cache eviction priority, only evict_last and "
            "evict_normal are supported");
    }

    if (getPredicate())
      return emitOpError("predicate supported only on prefetch tensormap");

  } else {
    return emitOpError(
        "requires specification of either cache level or tensormap");
  }

  return success();
}
  switch (getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    if (!getType().isInteger(1))
      return emitOpError("is_canceled query type returns an i1");
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    if (!getType().isInteger(32)) {
      return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
                         "get_first_cta_id_z query types return an i32");
    }
    break;
  }
  return success();
/// Packs the given `field` into `result`, where `result` is a 64-bit value
/// and `field` is at most 32 bits wide.
static llvm::Value *packValInto64Bits(llvm::IRBuilderBase &builder,
                                      llvm::Value *result, llvm::Value *field,
                                      unsigned sizeInBits, unsigned start) {
  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());

  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
  if (mask != 0xffffffffu)
    field = builder.CreateAnd(field, builder.getInt32(mask));

  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
  field = builder.CreateShl(field, start);

  return builder.CreateOr(result, field);
}
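// Illustrative worked example (a reading aid, not part of the lowering):
// packing a 3-bit field with value 0b101 at start = 49 into an all-zero
// result proceeds as
//   mask   = (1u << 3) - 1              // 0b111
//   field  = zext(field & mask) to i64  // 0x0000000000000005
//   field  = field << 49                // 0x000A000000000000
//   result = result | field             // bits [51:49] now hold 0b101
// which is how the base-offset field of the shared-memory descriptor below
// is placed.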
void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                LLVM::ModuleTranslation &mt,
                                                llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
  llvm::Value *smemDesc = builder.getInt64(0);

  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getStartAddr()), 14, 0);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);

  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);

  mt.mapValue(thisOp.getRes()) = smemDesc;
}
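// Bit layout of the 64-bit shared-memory descriptor, as implied by the
// packValInto64Bits calls above (no additional fields are implied):
//   [13:0]  matrix start address     (14 bits)
//   [29:16] leading-dimension offset (14 bits)
//   [45:32] stride-dimension offset  (14 bits)
//   [51:49] base offset              (3 bits)
//   [52]    leading-dimension mode   (1 bit)
//   [63:61] swizzle mode             (3 bits)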
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
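// Example expansion of the helpers above: with has_cpsize == true,
// GET_CP_ASYNC_ID(cg, 16, has_cpsize) selects
// llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16_s, and with
// has_cpsize == false it selects nvvm_cp_async_cg_shared_global_16.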
llvm::Intrinsic::ID
CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::SmallVector<llvm::Value *> &args) {
  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
             : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  // Fill the intrinsic arguments.
  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));

  return id;
}
mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch;

  // Fill the intrinsic arguments.
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  return {id, std::move(args)};
}
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  llvm::Intrinsic::ID id =
      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;

  // Fill the intrinsic arguments.
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  // Switch to the bytemask variant when a byte mask is provided.
  if (mlir::Value byteMask = thisOp.getByteMask()) {
    args.push_back(mt.lookupValue(byteMask));
    id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
  }

  return {id, std::move(args)};
}
mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the intrinsic arguments.
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));
  for (auto v : thisOp.getIm2colOffsets())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  const unsigned NI = llvm::Intrinsic::not_intrinsic;
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
      {NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
      {NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
      {NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
      {NI, NI, NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};

  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
                "TMALoadModes must match number of rows in IDTable");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");

  return {id, std::move(args)};
}
mlir::NVVM::IDArgPair
CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the intrinsic arguments.
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  const unsigned NI = llvm::Intrinsic::not_intrinsic;
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
      {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
      {NI, NI, NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};

  static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
                "TMAStoreModes must match number of rows in IDTable");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable(
        "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");

  return {id, std::move(args)};
}
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  // ...
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )
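// Example expansion: with hasRelu = true and hasSatFinite = true,
// GET_CVT_F2TF32_ID(rn, _relu, _satfinite) yields
// llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite; with both flags false it
// yields llvm::Intrinsic::nvvm_f2tf32_rn.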
llvm::Intrinsic::ID
ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, _relu, _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}
#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

llvm::Intrinsic::ID
ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
  switch (type) {
  case NVVM::ConvertFP6Type::E2M3:
    return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
  case NVVM::ConvertFP6Type::E3M2:
    return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
  }
  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
}
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

llvm::Intrinsic::ID
ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
                                     NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

  switch (type) {
  case NVVM::ConvertFP8Type::E4M3:
    return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
  case NVVM::ConvertFP8Type::E5M2:
    return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
  case NVVM::ConvertFP8Type::UE8M0:
    if (hasRoundingModeRZ)
      return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
    if (hasRoundingModeRP)
      return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
    break;
  }
  llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
}
#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

llvm::Intrinsic::ID
ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
  switch (type) {
  case NVVM::ConvertFP8Type::E4M3:
    return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
  case NVVM::ConvertFP8Type::E5M2:
    return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
  default:
    llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
  }
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}
mlir::NVVM::IDArgPair
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::Shared;
  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }

  // Fill the intrinsic arguments.
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}
mlir::NVVM::IDArgPair
Tcgen05DeallocOp::getIntrinsicIDAndArgs(Operation &op,
                                        LLVM::ModuleTranslation &mt,
                                        llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  // Fill the intrinsic arguments.
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
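// Example expansion: GET_TCGEN05_COMMIT_ID(cg2, /*is_shared=*/true,
// /*has_mc=*/true) selects llvm::Intrinsic::nvvm_tcgen05_commit_mc_shared_cg2,
// while the same call with has_mc == false selects
// llvm::Intrinsic::nvvm_tcgen05_commit_shared_cg2.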
mlir::NVVM::IDArgPair
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::Shared;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the intrinsic arguments.
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return {id, std::move(args)};
}
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()
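// Example expansion: GET_TCGEN05_CP_ID(_128x128b, srcFmt, /*is_2cta=*/true)
// selects llvm::Intrinsic::nvvm_tcgen05_cp_128x128b_b6x16_p32_cg2 when srcFmt
// is B6x16_P32, nvvm_tcgen05_cp_128x128b_b4x16_p64_cg2 when it is B4x16_P64,
// and nvvm_tcgen05_cp_128x128b_cg2 otherwise.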
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                unsigned vecLen) {
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
    return vecLen >= 2;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
    return vecLen >= 4;
  return true;
}
LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}
LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have a ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}
static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}
NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp4a_u_u,
      llvm::Intrinsic::nvvm_idp4a_u_s,
      llvm::Intrinsic::nvvm_idp4a_s_u,
      llvm::Intrinsic::nvvm_idp4a_s_s,
  };
  return {ids[type], args};
}
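// The 2-bit index packs the signedness of both inputs:
//   type = (isASigned << 1) | isBSigned
// so 0b00 -> u_u, 0b01 -> u_s, 0b10 -> s_u, and 0b11 -> s_s, matching the
// order of the `ids` table above (the same scheme is used for the 2-way op
// below).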
NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(builder.getInt1(curOp.getBHi()));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp2a_u_u,
      llvm::Intrinsic::nvvm_idp2a_u_s,
      llvm::Intrinsic::nvvm_idp2a_s_u,
      llvm::Intrinsic::nvvm_idp2a_s_s,
  };
  return {ids[type], args};
}
static llvm::Value *getParamCastedAddr(llvm::Value *addr,
                                       llvm::IRBuilderBase &builder) {
  return builder.CreateAddrSpaceCast(
      addr,
      llvm::PointerType::get(builder.getContext(),
                             llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
}
NVVM::IDArgPair
PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
                                  LLVM::ModuleTranslation &mt,
                                  llvm::IRBuilderBase &builder) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();

  llvm::SmallVector<llvm::Value *> args;
  llvm::Value *addr = mt.lookupValue(op.getAddr());
  args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
                                      : addr);

  if (op.getTensormap())
    return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};

  assert(cacheLevel && "expected cache level for non-tensormap prefetch");

  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
    return {llvm::Intrinsic::nvvm_prefetchu_L1, args};

  if (evictPriority && *cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
    case NVVM::CacheEvictionPriority::EvictNormal:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  switch (static_cast<MemSpace>(addressSpace)) {
  case MemSpace::Generic:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
               : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
  case MemSpace::Global:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
  case MemSpace::Local:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}
bool NVVM::InlinePtxOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  for (auto arg : getReadWriteArgs())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
  for (auto arg : getResults())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
  for (auto arg : getReadOnlyArgs())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getSmemAddress()));
  args.push_back(mt.lookupValue(curOp.getMbarrier()));

  llvm::Intrinsic::ID intrinsicID =
      curOp.getMulticast()
          ? llvm::Intrinsic::
                nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
          : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;

  return {intrinsicID, args};
}
NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));

  llvm::Intrinsic::ID intrinsicID;

  switch (curOp.getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    intrinsicID =
        llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
    break;
  }
  return {intrinsicID, args};
}
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // The kernel function attribute may only be attached to LLVM functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // maxntid / reqntid / cluster_dim must be integer arrays with at most 3
  // entries.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
    }
  }
  // minctasm / maxnreg / cluster_max_blocks must be integer attributes.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
    }
  }
  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
    if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
        !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be used along with "
             << "'" << NVVMDialect::getReqntidAttrName() << "' and "
             << "'" << NVVMDialect::getClusterDimAttrName() << "'";
    }
  }

  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError()
             << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}
unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
  return static_cast<unsigned>(getValue());
}

bool NVVMMemorySpaceAttr::isValidLoad(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  // ...
}

bool NVVMMemorySpaceAttr::isValidStore(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  // ...
}

bool NVVMMemorySpaceAttr::isValidAtomicOp(
    ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
    std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAtomicXchg(
    Type type, ptr::AtomicOrdering successOrdering,
    ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(/* ... */) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidPtrIntCast(/* ... */) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files, bool verifyTarget) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}
LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();
  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!gpuModuleOp)
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");
  const NVVMCheckSMVersion targetSMVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!targetSMVersion.isMinimumSMVersion())
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");
  WalkResult result = gpuModuleOp->walk([&](Operation *op) {
    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
      if (!reqOp.getRequiredMinSMVersion().isCompatibleWith(targetSMVersion)) {
        op->emitOpError() << "is not supported on " << getChip();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  return failure(result.wasInterrupted());
}
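// Illustrative sketch, not part of the original file: the kind of IR this
// verifier runs on, written as MLIR assembly with an assumed #nvvm.target
// spelling.
//
//   gpu.module @kernels [#nvvm.target<chip = "sm_90">] {
//     // kernels requiring at most sm_90 go here
//   }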
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"