31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
49static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
56 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(
ptr.getType());
57 return ptrTy.getAddressSpace() ==
static_cast<unsigned>(targetAS);
74 NVVMMemorySpace targetAS) {
75 unsigned AS =
static_cast<unsigned>(targetAS);
76 return builder.CreateAddrSpaceCast(
77 ptr, llvm::PointerType::get(builder.getContext(), AS));
81static llvm::nvvm::CTAGroupKind
84 case NVVM::CTAGroupKind::CTA_1:
85 return llvm::nvvm::CTAGroupKind::CG_1;
86 case NVVM::CTAGroupKind::CTA_2:
87 return llvm::nvvm::CTAGroupKind::CG_2;
89 llvm_unreachable(
"unsupported cta_group value");
101 size_t numIm2ColOffsets,
103 if (tensorDims < 1 || tensorDims > 5)
104 return emitError(loc,
"expects coordinates between 1 to 5 dimension");
112 "to use im2col mode, the tensor has to be at least 3-dimensional");
114 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
116 loc,
"im2col offsets must be 2 less than number of coordinates");
121LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
122 TMAStoreMode mode = getMode();
126 if (getPredicate()) {
127 if (mode != TMAStoreMode::TILE)
128 return emitError(
"Inline-ptx lowering supported only for Tile mode.");
129 if (getL2CacheHint())
130 return emitError(
"Inline-ptx lowering unsupported with L2 cache-hint.");
135 case TMAStoreMode::TILE:
137 case TMAStoreMode::IM2COL:
139 case TMAStoreMode::TILE_SCATTER4:
141 return emitError(
"Scatter4 mode expects 5 coordinates");
146LogicalResult CpAsyncOp::verify() {
147 if (getModifier() != LoadCacheModifierKind::CG &&
148 getModifier() != LoadCacheModifierKind::CA)
149 return emitError(
"Only CG and CA cache modifiers are supported.");
150 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
151 return emitError(
"expected byte size to be either 4, 8 or 16.");
152 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
153 return emitError(
"CG cache modifier is only support for 16 bytes copy.");
160 if (tensorDims < 1 || tensorDims > 5)
161 return emitError(loc,
"expects coordinates between 1 to 5 dimension");
163 auto checkTMALoadParams = [&](TMALoadMode mode,
bool isIm2col,
164 size_t expectedIm2colOff) -> LogicalResult {
165 if (isIm2col && (tensorDims < 3))
168 <<
" mode, the tensor has to be at least 3-dimensional";
170 if (numIm2colOff != expectedIm2colOff)
171 return emitError(loc) <<
" im2col offsets expected " << expectedIm2colOff
172 <<
" (provided " << numIm2colOff <<
")";
178 case TMALoadMode::TILE:
179 return checkTMALoadParams(mode,
false, 0);
180 case TMALoadMode::IM2COL:
181 return checkTMALoadParams(mode,
true, tensorDims - 2);
182 case TMALoadMode::IM2COL_W:
183 case TMALoadMode::IM2COL_W_128:
184 return checkTMALoadParams(mode,
true, 2);
185 case TMALoadMode::TILE_GATHER4:
186 return (tensorDims == 5)
187 ? checkTMALoadParams(mode,
false, 0)
188 :
emitError(loc,
"Gather4 mode expects 5 coordinates");
193LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
195 getMode(), getLoc());
198LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
199 TMALoadMode mode = getMode();
200 bool isCTAOnly = getIsCTAOnly();
201 if (getPredicate()) {
203 return emitError(
"Predicate is supported only for shared::cluster mode.");
204 if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
206 "Predicate is supported only for Tile and Im2col modes.");
208 NVVMMemorySpace expectedAS =
209 isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
210 unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().
getType())
212 if (AS != expectedAS)
215 ?
"Shared::cta destination requires address-space 3."
216 :
"Shared::cluster destination requires address-space 7.");
219 if (getMulticastMask())
220 return emitError(
"Multicast is not supported with shared::cta mode.");
222 return emitError(
"CTAGroup is not supported with shared::cta mode.");
227 getMode(), getLoc());
230LogicalResult CpAsyncBulkTensorReduceOp::verify() {
231 TMAStoreMode mode = getMode();
234 case TMAStoreMode::TILE:
236 case TMAStoreMode::IM2COL:
238 case TMAStoreMode::TILE_SCATTER4:
239 return emitError(
"Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
244LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
246 if (isSharedCTA && getMulticastMask())
247 return emitError(
"Multicast is not supported with shared::cta mode.");
253 NVVM::MemScopeKind scope,
254 Value retVal =
nullptr) {
255 if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
256 return op->
emitError(
"mbarrier scope must be either CTA or Cluster");
259 bool hasRetValue =
static_cast<bool>(retVal);
260 if (isSharedCluster && hasRetValue)
262 "mbarrier in shared_cluster space cannot return any value");
267LogicalResult MBarrierArriveOp::verify() {
272LogicalResult MBarrierArriveDropOp::verify() {
277LogicalResult MBarrierArriveExpectTxOp::verify() {
281 if (getPredicate()) {
282 if (getScope() != NVVM::MemScopeKind::CTA)
283 return emitError(
"mbarrier scope must be CTA when using predicate");
286 return emitError(
"mbarrier in shared_cluster space is not supported when "
290 return emitError(
"return-value is not supported when using predicate");
292 if (getRelaxed() ==
true)
293 return emitError(
"mbarrier with relaxed semantics is not supported when "
300LogicalResult MBarrierArriveDropExpectTxOp::verify() {
305LogicalResult MBarrierExpectTxOp::verify() {
309LogicalResult MBarrierCompleteTxOp::verify() {
313LogicalResult MBarrierTestWaitOp::verify() {
317LogicalResult MBarrierTryWaitOp::verify() {
321LogicalResult ConvertFloatToTF32Op::verify() {
322 using RndMode = NVVM::FPRoundingMode;
326 return emitError(
"Relu not supported with rna rounding mode.");
333 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
338LogicalResult ConvertF32x2ToF6x2Op::verify() {
341 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
343 << mlir::Float6E2M3FNType::get(ctx) <<
" and "
344 << mlir::Float6E3M2FNType::get(ctx)
345 <<
" types are supported for conversions from f32x2 to f6x2.";
350LogicalResult ConvertF32x2ToF8x2Op::verify() {
351 using RndMode = NVVM::FPRoundingMode;
352 using SatMode = NVVM::SaturationMode;
354 bool isRoundingModeRN = getRnd() == RndMode::RN;
355 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
356 bool isRoundingModeRP = getRnd() == RndMode::RP;
357 bool isSatFinite = getSat() == SatMode::SATFINITE;
359 bool hasRelu = getRelu();
364 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
366 if (!isRoundingModeRN) {
367 return emitOpError(
"Only RN rounding mode is supported for "
368 "conversions from f32x2 to ")
369 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
370 << mlir::Float8E5M2Type::get(ctx) <<
" types";
373 return emitOpError(
"Only SATFINITE saturation mode is supported "
376 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
377 << mlir::Float8E5M2Type::get(ctx) <<
" types";
381 .Case<mlir::Float8E8M0FNUType>([&](
mlir::Type) -> LogicalResult {
382 if (!(isRoundingModeRZ || isRoundingModeRP)) {
383 return emitOpError(
"Only RZ and RP rounding modes are supported for "
384 "conversions from f32x2 to ")
385 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
388 return emitOpError(
"relu not supported for conversions to ")
389 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
395 << mlir::Float8E4M3FNType::get(ctx) <<
", "
396 << mlir::Float8E5M2Type::get(ctx) <<
", and "
397 << mlir::Float8E8M0FNUType::get(ctx)
399 "supported for conversions from f32x2 to f8x2";
403LogicalResult ConvertF16x2ToF8x2Op::verify() {
406 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
408 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
409 << mlir::Float8E5M2Type::get(ctx)
410 <<
" types are supported for conversions from f16x2 to f8x2.";
415LogicalResult ConvertBF16x2ToF8x2Op::verify() {
416 using RndMode = NVVM::FPRoundingMode;
417 using SatMode = NVVM::SaturationMode;
419 bool isRoundingModeRN = getRnd() == RndMode::RN;
420 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
421 bool isRoundingModeRP = getRnd() == RndMode::RP;
422 bool isSatFinite = getSat() == SatMode::SATFINITE;
423 bool hasRelu = getRelu();
428 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
430 if (!isRoundingModeRN)
431 return emitOpError(
"Only RN rounding mode is supported for "
432 "conversions from bf16x2 to ")
433 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
434 << mlir::Float8E5M2Type::get(ctx) <<
" types";
436 return emitOpError(
"Only SATFINITE saturation mode is supported "
437 "for conversions from bf16x2 to ")
438 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
439 << mlir::Float8E5M2Type::get(ctx) <<
" types";
442 .Case<mlir::Float8E8M0FNUType>([&](
mlir::Type) -> LogicalResult {
443 if (!(isRoundingModeRZ || isRoundingModeRP))
444 return emitOpError(
"Only RZ and RP rounding modes are supported for "
445 "conversions from bf16x2 to ")
446 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
448 return emitOpError(
"relu not supported for conversions to ")
449 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
453 llvm_unreachable(
"Invalid conversion in ConvertBF16x2ToF8x2Op");
458LogicalResult ConvertF32x2ToF4x2Op::verify() {
461 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
463 << mlir::Float4E2M1FNType::get(ctx)
464 <<
" type is supported for conversions from f32x2 to f4x2.";
469LogicalResult ConvertF8x2ToF16x2Op::verify() {
472 if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
474 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
475 << mlir::Float8E5M2Type::get(ctx)
476 <<
" types are supported for conversions from f8x2 to f16x2.";
481LogicalResult ConvertF8x2ToBF16x2Op::verify() {
483 if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
485 << mlir::Float8E8M0FNUType::get(ctx)
486 <<
" type is supported for conversions from f8x2 to bf16x2.";
491LogicalResult ConvertF6x2ToF16x2Op::verify() {
494 if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
496 << mlir::Float6E2M3FNType::get(ctx) <<
" and "
497 << mlir::Float6E3M2FNType::get(ctx)
498 <<
" types are supported for conversions from f6x2 to f16x2.";
503LogicalResult ConvertF4x2ToF16x2Op::verify() {
506 if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
508 << mlir::Float4E2M1FNType::get(ctx)
509 <<
" type is supported for conversions from f4x2 to f16x2.";
514LogicalResult PermuteOp::verify() {
515 using Mode = NVVM::PermuteMode;
516 bool hasHi =
static_cast<bool>(getHi());
523 return emitError(
"mode '") << getMode() <<
"' requires 'hi' operand.";
531 << getMode() <<
"' does not accept 'hi' operand.";
546 static constexpr FPRoundingMode validRndModes[] = {
547 FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
549 if (!llvm::is_contained(validRndModes, rnd)) {
551 "Only RN, RZ, and RS rounding modes are supported for "
552 "conversions from f32x2 to ")
556 if (rnd == FPRoundingMode::RS) {
557 if (!hasRandomBits) {
558 return op->
emitOpError(
"random_bits is required for RS rounding mode.");
563 "random_bits not supported for RN and RZ rounding modes.");
570LogicalResult ConvertF32x2ToF16x2Op::verify() {
572 getRandomBits() ?
true :
false, *
this);
575LogicalResult ConvertF32x2ToBF16x2Op::verify() {
577 getRandomBits() ?
true :
false, *
this);
580LogicalResult ConvertF32x4ToF8x4Op::verify() {
583 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
585 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
586 << mlir::Float8E5M2Type::get(ctx)
587 <<
" types are supported for conversions from f32x4 to f8x4.";
592LogicalResult ConvertF32x4ToF6x4Op::verify() {
595 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
597 << mlir::Float6E2M3FNType::get(ctx) <<
" and "
598 << mlir::Float6E3M2FNType::get(ctx)
599 <<
" types are supported for conversions from f32x4 to f6x4.";
604LogicalResult ConvertF32x4ToF4x4Op::verify() {
607 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
608 return emitOpError(
"Only ") << mlir::Float4E2M1FNType::get(ctx)
609 <<
" type is supported for conversions from "
615LogicalResult BulkStoreOp::verify() {
616 if (getInitVal() != 0)
617 return emitOpError(
"only 0 is supported for initVal, got ") << getInitVal();
621LogicalResult PMEventOp::verify() {
622 auto eventId = getEventId();
623 auto maskedEventId = getMaskedEventId();
624 if (!maskedEventId && !eventId) {
625 return emitOpError() <<
"either `id` or `mask` must be set";
628 if (maskedEventId && eventId) {
629 return emitOpError() <<
"`id` and `mask` cannot be set at the same time";
633 if (eventId < 0 || eventId > 15) {
634 return emitOpError() <<
"`id` must be between 0 and 15";
638 return llvm::success();
644std::optional<mlir::NVVM::MMATypes>
645MmaOp::inferOperandMMAType(
Type operandElType,
bool isAccumulator) {
647 VectorType::get(2, Float16Type::get(operandElType.
getContext()));
648 if (operandElType.
isF64())
649 return NVVM::MMATypes::f64;
650 if (operandElType.
isF16() || operandElType == half2Type)
651 return NVVM::MMATypes::f16;
652 if (operandElType.
isF32() && isAccumulator)
653 return NVVM::MMATypes::f32;
654 if (operandElType.
isF32() && !isAccumulator)
655 return NVVM::MMATypes::tf32;
656 if (llvm::isa<IntegerType>(operandElType)) {
658 return NVVM::MMATypes::s32;
662 if (
auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
663 if (structType.getBody().empty())
665 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
672 return (type == MMATypes::u4 || type == MMATypes::s4);
676 return (type == MMATypes::u8 || type == MMATypes::s8);
681 type == MMATypes::s32;
684MMATypes MmaOp::accumPtxType() {
685 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
686 getODSOperands(2).getTypes().front(),
true);
687 assert(val.has_value() &&
"accumulator PTX type should always be inferrable");
691MMATypes MmaOp::resultPtxType() {
692 std::optional<mlir::NVVM::MMATypes> val =
693 inferOperandMMAType(getResult().
getType(),
true);
694 assert(val.has_value() &&
"result PTX type should always be inferrable");
700 struct MMAOperandFragment {
701 StringRef operandName;
702 StringRef ptxTypeAttr;
703 SmallVector<Value, 4> regs;
704 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
705 : operandName(name), ptxTypeAttr(ptxTypeName) {}
708 std::array<MMAOperandFragment, 3> frags{
709 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
710 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
711 MMAOperandFragment(
"C",
"")};
713 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
715 for (
unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
716 auto &frag = frags[fragIdx];
717 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
718 for (
auto operandIdx = varOperandSpec.first;
719 operandIdx < varOperandSpec.first + varOperandSpec.second;
721 frag.regs.push_back(this->getOperand(operandIdx));
722 if (operandIdx == 0) {
723 regTypes.push_back(this->getOperand(operandIdx).
getType());
726 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
727 regTypes.back(), fragIdx >= 2);
729 ignoreAttrNames.push_back(frag.ptxTypeAttr);
732 auto printMmaOperand = [&](
const MMAOperandFragment &frag) ->
void {
733 p <<
" " << frag.operandName;
739 for (
const auto &frag : frags) {
740 printMmaOperand(frag);
749 frags[1].regs[0].getType(),
750 frags[2].regs[0].getType()},
759 std::optional<MMAIntOverflow> intOverflow,
760 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
761 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
763 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
768 result.addOperands(operandA);
769 result.addOperands(operandB);
770 result.addOperands(operandC);
772 if (multiplicandPtxTypes) {
773 result.addAttribute(
"multiplicandAPtxType",
774 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
775 result.addAttribute(
"multiplicandBPtxType",
776 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
778 if (
auto res = inferOperandMMAType(operandA[0].
getType(),
false))
779 result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
780 if (
auto res = inferOperandMMAType(operandB[0].
getType(),
false))
781 result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
784 if (multiplicandLayouts) {
785 result.addAttribute(
"layoutA",
786 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
787 result.addAttribute(
"layoutB",
788 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
790 result.addAttribute(
"layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
791 result.addAttribute(
"layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
794 if (intOverflow.has_value())
795 result.addAttribute(
"intOverflowBehavior",
796 MMAIntOverflowAttr::get(ctx, *intOverflow));
797 if (b1Op.has_value())
798 result.addAttribute(
"b1Op", MMAB1OpAttr::get(ctx, *b1Op));
800 result.addTypes(resultType);
802 MmaOp::getOperandSegmentSizeAttr(),
804 static_cast<int32_t>(operandB.size()),
805 static_cast<int32_t>(operandC.size())}));
813 struct MMAOperandFragment {
814 std::optional<MMATypes> elemtype;
815 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
816 SmallVector<Type> regTypes;
820 std::array<MMAOperandFragment, 4> frags;
826 MMAOperandFragment &frag) -> LogicalResult {
856 if (operandTypes.size() != 3)
859 "expected one type for each operand segment but got " +
860 Twine(operandTypes.size()) +
" types");
861 for (
const auto &iter : llvm::enumerate(operandTypes)) {
862 auto &frag = frags[iter.index()];
863 frag.regTypes.resize(frag.regs.size(), iter.value());
867 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
874 frags[3].elemtype = inferOperandMMAType(resultType,
true);
876 std::array<StringRef, 2> names{
"multiplicandAPtxType",
877 "multiplicandBPtxType"};
878 for (
unsigned idx = 0; idx < names.size(); idx++) {
879 const auto &frag = frags[idx];
880 std::optional<NamedAttribute> attr = namedAttributes.
getNamed(names[idx]);
881 if (!frag.elemtype.has_value() && !attr.has_value()) {
884 "attribute " + names[idx] +
885 " is not provided explicitly and cannot be inferred");
887 if (!attr.has_value())
889 names[idx], MMATypesAttr::get(parser.
getContext(), *frag.elemtype));
892 result.addTypes(resultType);
893 if (!namedAttributes.
empty())
894 result.addAttributes(namedAttributes);
895 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
897 static_cast<int32_t>(frags[0].regs.size()),
898 static_cast<int32_t>(frags[1].regs.size()),
899 static_cast<int32_t>(frags[2].regs.size()),
904LogicalResult MmaOp::verify() {
906 auto f16Ty = Float16Type::get(context);
907 auto i32Ty = IntegerType::get(context, 32);
908 auto f16x2Ty = VectorType::get(2, f16Ty);
909 auto f32Ty = Float32Type::get(context);
910 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
911 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
914 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
917 auto f16x2x2StructTy =
918 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
920 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
922 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
924 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
925 getShapeAttr().getK()};
931 AllowedShapes allowedShapes;
932 AllowedTypes expectedA;
933 AllowedTypes expectedB;
934 AllowedTypes expectedC;
939 if (mmaShape[0] == 16) {
941 Type multiplicandFragType;
942 switch (*getMultiplicandAPtxType()) {
945 multiplicandFragType = i32Ty;
946 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
947 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
951 multiplicandFragType = i32Ty;
952 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
953 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
957 multiplicandFragType = f16x2Ty;
958 expectedResult.push_back(f16x2x2StructTy);
959 expectedResult.push_back(f32x4StructTy);
973 return emitError(
"invalid shape or multiplicand type: ")
974 << getMultiplicandAPtxType().value();
978 expectedResult.push_back(s32x4StructTy);
979 expectedC.emplace_back(4, i32Ty);
980 multiplicandFragType = i32Ty;
982 expectedC.emplace_back(2, f16x2Ty);
983 expectedC.emplace_back(4, f32Ty);
986 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
987 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
988 expectedA.emplace_back(unitA, multiplicandFragType);
989 expectedB.emplace_back(unitB, multiplicandFragType);
990 allowedShapes.push_back({16, 8, kFactor});
991 allowedShapes.push_back({16, 8, kFactor * 2});
993 if (resultPtxType() != accumPtxType())
998 if (mmaShape[0] == 8) {
999 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1000 expectedA.emplace_back(2, f16x2Ty);
1001 expectedB.emplace_back(2, f16x2Ty);
1002 expectedResult.push_back(f16x2x4StructTy);
1003 expectedResult.push_back(f32x8StructTy);
1004 expectedC.emplace_back(4, f16x2Ty);
1005 expectedC.emplace_back(8, f32Ty);
1006 allowedShapes.push_back({8, 8, 4});
1008 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1009 Type f64Ty = Float64Type::get(context);
1010 expectedA.emplace_back(1, f64Ty);
1011 expectedB.emplace_back(1, f64Ty);
1012 expectedC.emplace_back(2, f64Ty);
1013 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1015 allowedShapes.push_back({8, 8, 4});
1018 expectedA.push_back({i32Ty});
1019 expectedB.push_back({i32Ty});
1020 expectedC.push_back({i32Ty, i32Ty});
1021 expectedResult.push_back(s32x2StructTy);
1023 allowedShapes.push_back({8, 8, 32});
1025 allowedShapes.push_back({8, 8, 16});
1026 if (getMultiplicandAPtxType().value() == MMATypes::b1)
1027 allowedShapes.push_back({8, 8, 128});
1031 std::string errorMessage;
1032 llvm::raw_string_ostream errorStream(errorMessage);
1035 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1036 !llvm::is_contained(allowedShapes, mmaShape)) {
1037 errorStream <<
"unimplemented variant for MMA shape <";
1038 llvm::interleaveComma(mmaShape, errorStream);
1044 std::array<StringRef, 3> operandNames{
"A",
"B",
"C"};
1045 for (
const auto &iter : llvm::enumerate(
1047 auto spec = this->getODSOperandIndexAndLength(iter.index());
1049 operand_type_begin() + spec.first +
1051 bool match = llvm::is_contained(iter.value(), operandTySeg);
1054 errorStream <<
"Could not match types for the "
1055 << operandNames[iter.index()]
1056 <<
" operands; expected one of ";
1057 for (
const auto &x : iter.value()) {
1058 errorStream << x.size() <<
"x" << x[0] <<
" ";
1060 errorStream <<
"but got ";
1061 llvm::interleaveComma(operandTySeg, errorStream);
1067 if (!llvm::any_of(expectedResult, [&](
Type expectedResultType) {
1068 return expectedResultType == getResult().getType();
1071 <<
"Could not match allowed types for the result; expected one of ";
1072 llvm::interleaveComma(expectedResult, errorStream);
1073 errorStream <<
" but got " << getResult().getType();
1078 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
1079 return emitOpError(
"op requires " + getB1OpAttrName().strref() +
1087 if (!getIntOverflowBehavior())
1089 getIntOverflowBehaviorAttrName().strref() +
1097 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
1098 getMultiplicandAPtxType() == MMATypes::f16);
1100 if (!isM8N8K4_F16) {
1102 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
1103 return emitOpError(
"requires layoutA = #nvvm.mma_layout<row> and "
1104 "layoutB = #nvvm.mma_layout<col> for shape <")
1105 << mmaShape[0] <<
", " << mmaShape[1] <<
", " << mmaShape[2]
1106 <<
"> with element types " << *getMultiplicandAPtxType() <<
" and "
1107 << *getMultiplicandBPtxType()
1108 <<
". Only m8n8k4 with f16 supports other layouts.";
1115MMATypes MmaSpOp::accumPtxType() {
1116 std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
1117 getODSOperands(2).getTypes().front(),
true);
1118 assert(val.has_value() &&
"accumulator PTX type should always be inferrable");
1122MMATypes MmaSpOp::resultPtxType() {
1123 std::optional<mlir::NVVM::MMATypes> val =
1124 MmaOp::inferOperandMMAType(getResult().
getType(),
true);
1125 assert(val.has_value() &&
"result PTX type should always be inferrable");
1131 llvm::IRBuilderBase &builder) {
1132 auto thisOp = cast<NVVM::MmaSpOp>(op);
1140 auto intId = MmaSpOp::getIntrinsicID(
1141 thisOp.getShape().getM(), thisOp.getShape().getN(),
1142 thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
1143 thisOp.getOrderedMetadata(), thisOp.getKind(),
1144 *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
1145 thisOp.accumPtxType(), thisOp.resultPtxType());
1147 return {intId, args};
1152 struct MMAOperandFragment {
1153 StringRef operandName;
1154 StringRef ptxTypeAttr;
1155 SmallVector<Value, 4> regs;
1156 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1157 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1160 std::array<MMAOperandFragment, 5> frags{
1161 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
1162 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
1163 MMAOperandFragment(
"C",
""), MMAOperandFragment(
"sparseMetadata",
""),
1164 MMAOperandFragment(
"selector",
"")};
1166 mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
1169 for (
unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
1170 auto &frag = frags[fragIdx];
1171 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
1172 for (
auto operandIdx = varOperandSpec.first;
1173 operandIdx < varOperandSpec.first + varOperandSpec.second;
1175 frag.regs.push_back(this->getOperand(operandIdx));
1176 if (operandIdx == varOperandSpec.first) {
1177 regTypes.push_back(this->getOperand(operandIdx).
getType());
1180 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
1181 regTypes.back(), fragIdx >= 2);
1183 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1187 frags[3].regs.push_back(getSparseMetadata());
1188 frags[4].regs.push_back(getSparsitySelector());
1190 auto printMmaSpOperand = [&](
const MMAOperandFragment &frag) ->
void {
1191 p <<
" " << frag.operandName;
1197 for (
const auto &frag : frags)
1198 printMmaSpOperand(frag);
1203 for (
int i = 0; i < 3; ++i) {
1208 p <<
") -> " << getResult().getType();
1215 std::optional<MMAIntOverflow> intOverflow,
1216 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1218 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
1223 result.addOperands(operandA);
1224 result.addOperands(operandB);
1225 result.addOperands(operandC);
1226 result.addOperands(sparseMetadata);
1227 result.addOperands(sparsitySelector);
1229 if (multiplicandPtxTypes) {
1230 result.addAttribute(
"multiplicandAPtxType",
1231 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1232 result.addAttribute(
"multiplicandBPtxType",
1233 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1235 if (
auto res = MmaOp::inferOperandMMAType(operandA[0].
getType(),
false))
1236 result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1237 if (
auto res = MmaOp::inferOperandMMAType(operandB[0].
getType(),
false))
1238 result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1241 if (intOverflow.has_value())
1242 result.addAttribute(
"intOverflowBehavior",
1243 MMAIntOverflowAttr::get(ctx, *intOverflow));
1245 result.addTypes(resultType);
1247 MmaSpOp::getOperandSegmentSizeAttr(),
1249 static_cast<int32_t>(operandB.size()),
1250 static_cast<int32_t>(operandC.size()), 1,
1255 struct MMAOperandFragment {
1256 std::optional<MMATypes> elemtype;
1257 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1258 SmallVector<Type> regTypes;
1262 std::array<MMAOperandFragment, 6> frags;
1267 auto parseMmaSpOperand = [&](StringRef operandName,
1268 MMAOperandFragment &frag) -> LogicalResult {
1279 if (parseMmaSpOperand(
"A", frags[0]).
failed())
1281 if (parseMmaSpOperand(
"B", frags[1]).
failed())
1283 if (parseMmaSpOperand(
"C", frags[2]).
failed())
1285 if (parseMmaSpOperand(
"sparseMetadata", frags[3]).
failed())
1287 if (parseMmaSpOperand(
"selector", frags[4]).
failed())
1303 if (operandTypes.size() != 3)
1306 "expected one type for each operand segment but got " +
1307 Twine(operandTypes.size()) +
" types");
1308 for (
const auto &iter : llvm::enumerate(operandTypes)) {
1309 auto &frag = frags[iter.index()];
1310 frag.regTypes.resize(frag.regs.size(), iter.value());
1315 MmaOp::inferOperandMMAType(frag.regTypes[0],
1323 MmaOp::inferOperandMMAType(resultType,
true);
1338 std::array<StringRef, 2> names{
"multiplicandAPtxType",
1339 "multiplicandBPtxType"};
1340 for (
unsigned idx = 0; idx < names.size(); idx++) {
1341 const auto &frag = frags[idx];
1342 std::optional<NamedAttribute> attr = namedAttributes.
getNamed(names[idx]);
1343 if (!frag.elemtype.has_value() && !attr.has_value()) {
1346 "attribute " + names[idx] +
1347 " is not provided explicitly and cannot be inferred");
1349 if (!attr.has_value())
1351 names[idx], MMATypesAttr::get(parser.
getContext(), *frag.elemtype));
1354 result.addTypes(resultType);
1355 if (!namedAttributes.
empty())
1356 result.addAttributes(namedAttributes);
1357 result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
1359 static_cast<int32_t>(frags[0].regs.size()),
1360 static_cast<int32_t>(frags[1].regs.size()),
1361 static_cast<int32_t>(frags[2].regs.size()),
1368LogicalResult MmaSpOp::verify() {
1370 auto f16Ty = Float16Type::get(context);
1371 auto i32Ty = IntegerType::get(context, 32);
1372 auto f16x2Ty = VectorType::get(2, f16Ty);
1373 auto f32Ty = Float32Type::get(context);
1374 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
1375 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
1377 auto s32x4StructTy =
1378 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
1379 auto f32x8StructTy =
1381 auto f16x2x2StructTy =
1382 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
1383 auto f32x4StructTy =
1384 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
1385 auto s32x2StructTy =
1386 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
1388 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1389 getShapeAttr().getK()};
1395 AllowedShapes allowedShapes;
1396 AllowedTypes expectedA;
1397 AllowedTypes expectedB;
1398 AllowedTypes expectedC;
1403 if (mmaShape[0] == 16) {
1405 Type multiplicandFragType;
1406 switch (*getMultiplicandAPtxType()) {
1407 case MMATypes::tf32:
1409 multiplicandFragType = i32Ty;
1410 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1411 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1413 allowedShapes.push_back({16, 8, 8});
1414 allowedShapes.push_back({16, 8, 16});
1416 case MMATypes::bf16:
1418 multiplicandFragType = i32Ty;
1419 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1420 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1422 allowedShapes.push_back({16, 8, 16});
1423 allowedShapes.push_back({16, 8, 32});
1427 multiplicandFragType = f16x2Ty;
1428 expectedResult.push_back(f16x2x2StructTy);
1429 expectedResult.push_back(f32x4StructTy);
1431 allowedShapes.push_back({16, 8, 16});
1432 allowedShapes.push_back({16, 8, 32});
1438 allowedShapes.push_back({16, 8, 64});
1439 allowedShapes.push_back({16, 8, 128});
1445 allowedShapes.push_back({16, 8, 32});
1446 allowedShapes.push_back({16, 8, 64});
1448 case MMATypes::e4m3:
1449 case MMATypes::e5m2:
1450 case MMATypes::e3m2:
1451 case MMATypes::e2m3:
1452 case MMATypes::e2m1:
1454 multiplicandFragType = i32Ty;
1455 expectedResult.push_back(f16x2x2StructTy);
1456 expectedResult.push_back(f32x4StructTy);
1458 allowedShapes.push_back({16, 8, 64});
1461 return emitError(
"invalid shape or multiplicand type: ")
1462 << getMultiplicandAPtxType().value();
1466 expectedResult.push_back(s32x4StructTy);
1467 expectedC.emplace_back(4, i32Ty);
1468 multiplicandFragType = i32Ty;
1469 }
else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
1470 *getMultiplicandAPtxType() <= MMATypes::e2m1) {
1472 expectedC.emplace_back(2, f16x2Ty);
1473 expectedC.emplace_back(4, f32Ty);
1475 expectedC.emplace_back(2, f16x2Ty);
1476 expectedC.emplace_back(4, f32Ty);
1481 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
1482 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1483 expectedA.emplace_back(unitA, multiplicandFragType);
1484 expectedB.emplace_back(unitB, multiplicandFragType);
1486 if (resultPtxType() != accumPtxType())
1491 if (mmaShape[0] == 8) {
1492 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1493 expectedA.emplace_back(2, f16x2Ty);
1494 expectedB.emplace_back(2, f16x2Ty);
1495 expectedResult.push_back(f16x2x4StructTy);
1496 expectedResult.push_back(f32x8StructTy);
1497 expectedC.emplace_back(4, f16x2Ty);
1498 expectedC.emplace_back(8, f32Ty);
1499 allowedShapes.push_back({8, 8, 4});
1501 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1502 Type f64Ty = Float64Type::get(context);
1503 expectedA.emplace_back(1, f64Ty);
1504 expectedB.emplace_back(1, f64Ty);
1505 expectedC.emplace_back(2, f64Ty);
1506 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1508 allowedShapes.push_back({8, 8, 4});
1511 expectedA.push_back({i32Ty});
1512 expectedB.push_back({i32Ty});
1513 expectedC.push_back({i32Ty, i32Ty});
1514 expectedResult.push_back(s32x2StructTy);
1516 allowedShapes.push_back({8, 8, 32});
1518 allowedShapes.push_back({8, 8, 16});
1522 std::string errorMessage;
1523 llvm::raw_string_ostream errorStream(errorMessage);
1526 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1527 !llvm::is_contained(allowedShapes, mmaShape)) {
1528 errorStream <<
"unimplemented variant for MMA shape <";
1529 llvm::interleaveComma(mmaShape, errorStream);
1535 std::array<StringRef, 3> operandNames{
"A",
"B",
"C"};
1536 for (
const auto &iter : llvm::enumerate(
1538 auto spec = this->getODSOperandIndexAndLength(iter.index());
1540 operand_type_begin() + spec.first +
1542 bool match = llvm::is_contained(iter.value(), operandTySeg);
1545 errorStream <<
"Could not match types for the "
1546 << operandNames[iter.index()]
1547 <<
" operands; expected one of ";
1548 for (
const auto &x : iter.value()) {
1549 errorStream << x.size() <<
"x" << x[0] <<
" ";
1551 errorStream <<
"but got ";
1552 llvm::interleaveComma(operandTySeg, errorStream);
1558 if (!llvm::any_of(expectedResult, [&](
Type expectedResultType) {
1559 return expectedResultType == getResult().getType();
1562 <<
"Could not match allowed types for the result; expected one of ";
1563 llvm::interleaveComma(expectedResult, errorStream);
1564 errorStream <<
" but got " << getResult().getType();
1572 if (!getIntOverflowBehavior())
1574 getIntOverflowBehaviorAttrName().strref() +
1579 if (!getSparseMetadata().
getType().isInteger(32)) {
1580 return emitOpError() <<
"sparse metadata must be i32 type";
1584 if (!getSparsitySelector().
getType().isInteger(32)) {
1585 return emitOpError() <<
"sparsity selector must be i32 type";
1597struct MMAOperandFragment {
1598 StringRef operandName;
1599 StringRef ptxTypeAttr;
1600 SmallVector<Value, 4> regs;
1601 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1602 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1609 p <<
" " << name <<
"[";
1628template <
typename Op>
1633 for (
unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
1634 auto &frag = frags[fragIdx];
1635 auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
1636 for (
auto operandIdx = varOperandSpec.first;
1637 operandIdx < varOperandSpec.first + varOperandSpec.second;
1639 frag.regs.push_back(op.getOperand(operandIdx));
1640 if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
1641 regTypes.push_back(op.getOperand(operandIdx).getType());
1645 regTypes.push_back(frag.regs[0].getType());
1647 std::optional<MMATypes> inferredType =
1648 MmaOp::inferOperandMMAType(regTypes.back(),
1651 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1662 auto typeParser = [&]() {
1666 operandTypes.push_back(ty);
1672 if (operandTypes.size() != 3)
1674 "expected exactly 3 types");
1683 if (!attrs.
get(
"multiplicandAPtxType")) {
1684 if (
auto inferredType =
1685 MmaOp::inferOperandMMAType(operandTypes[0],
false)) {
1686 attrs.
set(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
1689 if (!attrs.
get(
"multiplicandBPtxType")) {
1690 if (
auto inferredType =
1691 MmaOp::inferOperandMMAType(operandTypes[1],
false)) {
1692 attrs.
set(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
1698template <
typename OpType>
1701 ScaleVecSize scaleVecSize,
1702 BlockScaleFormat blockScaleFormat,
1703 MMABlockScaleKind kind) {
1705 auto &properties =
result.getOrAddProperties<
typename OpType::Properties>();
1706 properties.setShape(
1708 properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
1709 properties.setBlockScaleFormat(
1710 BlockScaleFormatAttr::get(ctx, blockScaleFormat));
1711 properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
1718 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1719 if (multiplicandPtxTypes) {
1720 result.addAttribute(
"multiplicandAPtxType",
1721 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1722 result.addAttribute(
"multiplicandBPtxType",
1723 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1725 if (
auto res = MmaOp::inferOperandMMAType(operandA[0].
getType(),
false))
1726 result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1727 if (
auto res = MmaOp::inferOperandMMAType(operandB[0].
getType(),
false))
1728 result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1733template <
typename OpTy>
1735 return *MmaOp::inferOperandMMAType(
1736 cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
1746 std::array<MMAOperandFragment, 3> frags{
1747 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
1748 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
1749 MMAOperandFragment(
"C",
"")};
1751 mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
1756 for (
const auto &frag : frags)
1761 {getScaleAData(), getByteIdA(), getThreadIdA()});
1763 {getScaleBData(), getByteIdB(), getThreadIdB()});
1770 frags[1].regs[0].getType(),
1771 frags[2].regs[0].getType()},
1777ParseResult MmaBlockScaleOp::parse(
OpAsmParser &parser,
1779 struct LocalOperandFragment {
1780 std::optional<MMATypes> elemtype;
1781 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1785 std::array<LocalOperandFragment, 3> frags;
1814 for (
const auto &[idx, frag] : llvm::enumerate(frags)) {
1815 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
1818 .resolveOperands(frag.regs, operandTypes[idx], parser.
getNameLoc(),
1828 .resolveOperands(scaleAOperands, scaleTypes, parser.
getNameLoc(),
1838 result.addAttributes(namedAttributes);
1842 result.addTypes(resultTypes);
1843 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1845 static_cast<int32_t>(frags[0].regs.size()),
1846 static_cast<int32_t>(frags[1].regs.size()),
1847 static_cast<int32_t>(frags[2].regs.size()),
1858void MmaBlockScaleOp::build(
1863 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
1864 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
1865 MMABlockScaleKind kind) {
1866 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
1869 blockScaleFormat, kind);
1871 result.addOperands(operandA);
1872 result.addOperands(operandB);
1873 result.addOperands(operandC);
1875 {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
1878 multiplicandPtxTypes);
1880 result.addTypes(resultType);
1881 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1883 static_cast<int32_t>(operandA.size()),
1884 static_cast<int32_t>(operandB.size()),
1885 static_cast<int32_t>(operandC.size()),
1897 auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
1901 for (
Value operand : curOp.getOperandA())
1903 for (
Value operand : curOp.getOperandB())
1905 for (
Value operand : curOp.getOperandC())
1909 args.push_back(mt.
lookupValue(curOp.getScaleAData()));
1910 args.push_back(mt.
lookupValue(curOp.getByteIdA()));
1911 args.push_back(mt.
lookupValue(curOp.getThreadIdA()));
1912 args.push_back(mt.
lookupValue(curOp.getScaleBData()));
1913 args.push_back(mt.
lookupValue(curOp.getByteIdB()));
1914 args.push_back(mt.
lookupValue(curOp.getThreadIdB()));
1916 unsigned intId = MmaBlockScaleOp::getIntrinsicID(
1917 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
1918 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
1920 curOp.getBlockScaleFormat(), curOp.getKind());
1922 return {intId, args};
1925LogicalResult MmaBlockScaleOp::verify() {
1931 if (m == 16 && n == 8 && k == 64) {
1932 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
1933 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
1935 "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
1936 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
1937 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
1939 "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
1940 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
1942 "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
1943 }
else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
1944 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
1945 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
1946 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
1947 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
1948 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
1950 "attributes for mma.m16n8k64.mxf4nvf4");
1954 }
else if (m == 16 && n == 8 && k == 32) {
1955 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
1956 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
1957 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
1959 emitOpError(
"unsupported Kind, ScaleVecSize and BlockScaleFormat "
1960 "attributes for mma.m16n8k32");
1973 std::array<MMAOperandFragment, 3> frags{
1974 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
1975 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
1976 MMAOperandFragment(
"C",
"")};
1978 mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
1983 for (
const auto &frag : frags)
1992 {getScaleAData(), getByteIdA(), getThreadIdA()});
1994 {getScaleBData(), getByteIdB(), getThreadIdB()});
2001 frags[1].regs[0].getType(),
2002 frags[2].regs[0].getType()},
2008ParseResult MmaSpBlockScaleOp::parse(
OpAsmParser &parser,
2010 struct LocalOperandFragment {
2011 std::optional<MMATypes> elemtype;
2012 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
2016 std::array<LocalOperandFragment, 3> frags;
2052 for (
const auto &[idx, frag] : llvm::enumerate(frags)) {
2053 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
2056 .resolveOperands(frag.regs, operandTypes[idx], parser.
getNameLoc(),
2065 .resolveOperands(metadataOperands, i32Type, parser.
getNameLoc(),
2078 .resolveOperands(scaleAOperands, scaleTypes, parser.
getNameLoc(),
2088 result.addAttributes(namedAttributes);
2093 if (!
result.attributes.get(
"orderedMetadata"))
2096 result.addTypes(resultTypes);
2097 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2099 static_cast<int32_t>(frags[0].regs.size()),
2100 static_cast<int32_t>(frags[1].regs.size()),
2101 static_cast<int32_t>(frags[2].regs.size()),
2114void MmaSpBlockScaleOp::build(
2120 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
2121 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
2122 MMABlockScaleKind kind) {
2123 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
2126 builder,
result,
shape, scaleVecSize, blockScaleFormat, kind);
2129 result.addOperands(operandA);
2130 result.addOperands(operandB);
2131 result.addOperands(operandC);
2132 result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
2133 threadIdA, scaleBData, byteIdB, threadIdB});
2136 multiplicandPtxTypes);
2138 result.addTypes(resultType);
2139 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2141 static_cast<int32_t>(operandA.size()),
2142 static_cast<int32_t>(operandB.size()),
2143 static_cast<int32_t>(operandC.size()),
2157 auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
2161 for (
Value operand : curOp.getOperandA())
2163 for (
Value operand : curOp.getOperandB())
2165 for (
Value operand : curOp.getOperandC())
2169 args.push_back(mt.
lookupValue(curOp.getSparseMetadata()));
2170 args.push_back(mt.
lookupValue(curOp.getSparsitySelector()));
2173 args.push_back(mt.
lookupValue(curOp.getScaleAData()));
2174 args.push_back(mt.
lookupValue(curOp.getByteIdA()));
2175 args.push_back(mt.
lookupValue(curOp.getThreadIdA()));
2176 args.push_back(mt.
lookupValue(curOp.getScaleBData()));
2177 args.push_back(mt.
lookupValue(curOp.getByteIdB()));
2178 args.push_back(mt.
lookupValue(curOp.getThreadIdB()));
2180 unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
2181 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
2182 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
2184 curOp.getBlockScaleFormat(), curOp.getKind());
2186 return {intId, args};
2189LogicalResult MmaSpBlockScaleOp::verify() {
2191 if (!getOrderedMetadata()) {
2192 return emitOpError(
"'orderedMetadata' attribute is mandatory");
2200 if (m == 16 && n == 8 && k == 128) {
2201 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2202 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2204 "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
2205 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2206 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2208 "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
2209 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2211 "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
2212 }
else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2213 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2214 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2215 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2216 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
2217 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
2219 "attributes for mma.m16n8k128.mxf4nvf4");
2223 }
else if (m == 16 && n == 8 && k == 64) {
2224 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2225 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2226 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2228 emitOpError(
"unsupported Kind, ScaleVecSize and BlockScaleFormat "
2229 "attributes for mma.m16n8k64");
2236LogicalResult ShflOp::verify() {
2237 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(
getType());
2239 auto verifyTypeError = [&](Twine desc,
Type expectedType,
2240 Type actualType) -> LogicalResult {
2241 return emitOpError(
"expected " + desc +
" to be of type ")
2242 << expectedType <<
" but got " << actualType <<
" instead";
2245 if (returnStructType) {
2246 if (!getReturnValueAndIsValid())
2247 return emitOpError(
"\"return_value_and_is_valid\" attribute must be "
2248 "specified when the return type is a struct type");
2250 if (returnStructType.getBody().size() != 2)
2251 return emitOpError(
"expected return type to be a two-element struct");
2254 auto resultType = returnStruct[0];
2255 if (resultType != getVal().
getType())
2256 return verifyTypeError(
"first element in the returned struct",
2257 getVal().
getType(), resultType);
2259 auto predicateType = returnStruct[1];
2260 if (!predicateType.isInteger(1))
2261 return verifyTypeError(
"second element in the returned struct",
2265 if (getReturnValueAndIsValid())
2266 return emitOpError(
"expected return type to be a two-element struct");
2269 return verifyTypeError(
"return type", getVal().
getType(),
getType());
2275 NVVM::MMAFrag frag,
int nRow,
2278 unsigned numberElements = 0;
2281 Type f16x2 = VectorType::get(2, builder.getF16Type());
2282 if (type == NVVM::MMATypes::f16) {
2283 elementType = f16x2;
2284 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2288 }
else if (type == NVVM::MMATypes::f32) {
2289 elementType = builder.getF32Type();
2291 }
else if (type == NVVM::MMATypes::f64) {
2292 elementType = builder.getF64Type();
2293 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2297 }
else if (type == NVVM::MMATypes::tf32) {
2298 elementType = builder.getI32Type();
2300 }
else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
2301 elementType = builder.getI32Type();
2302 int parallelSize = 0;
2303 if (frag == NVVM::MMAFrag::a)
2304 parallelSize = nRow;
2305 if (frag == NVVM::MMAFrag::b)
2306 parallelSize = nCol;
2309 if (parallelSize == 16)
2312 else if (parallelSize == 8)
2314 else if (parallelSize == 32)
2316 }
else if (type == NVVM::MMATypes::s32) {
2317 elementType = builder.getI32Type();
2320 assert(numberElements != 0 && elementType !=
nullptr);
2321 return std::make_pair(elementType, numberElements);
2324static std::pair<mlir::Type, unsigned>
2328 if (frag == NVVM::MMAFrag::a) {
2331 }
else if (frag == NVVM::MMAFrag::b) {
2338 assert(nRow && nCol);
2342LogicalResult NVVM::WMMALoadOp::verify() {
2343 unsigned addressSpace =
2344 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
2345 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2346 addressSpace != NVVMMemorySpace::Shared)
2347 return emitOpError(
"expected source pointer in memory "
2350 if (NVVM::WMMALoadOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
2351 getEltype(), getFrag()) == 0)
2352 return emitOpError() <<
"invalid attribute combination";
2357 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
2359 return emitOpError(
"expected destination type to be f64");
2363 Type dstType = LLVM::LLVMStructType::getLiteral(
2366 return emitOpError(
"expected destination type is a structure of ")
2367 << typeInfo.second <<
" elements of type " << typeInfo.first;
2371LogicalResult NVVM::WMMAStoreOp::verify() {
2372 unsigned addressSpace =
2373 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
2374 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2375 addressSpace != NVVMMemorySpace::Shared)
2376 return emitOpError(
"expected operands to be a source pointer in memory "
2379 if (NVVM::WMMAStoreOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
2381 return emitOpError() <<
"invalid attribute combination";
2384 if (getArgs().size() != typeInfo.second)
2385 return emitOpError() <<
"expected " << typeInfo.second <<
" data operands";
2386 if (llvm::any_of(getArgs(), [&typeInfo](
Value operands) {
2387 return operands.
getType() != typeInfo.first;
2389 return emitOpError() <<
"expected data operands of type " << typeInfo.first;
2393LogicalResult NVVM::WMMAMmaOp::verify() {
2394 if (NVVM::WMMAMmaOp::getIntrinsicID(
getM(),
getN(), getK(), getLayoutA(),
2395 getLayoutB(), getEltypeA(),
2397 return emitOpError() <<
"invalid attribute combination";
2405 arguments.append(typeInfoA.second, typeInfoA.first);
2406 arguments.append(typeInfoB.second, typeInfoB.first);
2407 arguments.append(typeInfoC.second, typeInfoC.first);
2408 unsigned numArgs = arguments.size();
2409 if (getArgs().size() != numArgs)
2410 return emitOpError() <<
"expected " << numArgs <<
" arguments";
2411 for (
unsigned i = 0; i < numArgs; i++) {
2412 if (getArgs()[i].
getType() != arguments[i])
2413 return emitOpError() <<
"expected argument " << i <<
" to be of type "
2416 Type dstType = LLVM::LLVMStructType::getLiteral(
2419 return emitOpError(
"expected destination type is a structure of ")
2420 << typeInfoC.second <<
" elements of type " << typeInfoC.first;
2424LogicalResult NVVM::LdMatrixOp::verify() {
2426 if (m == 8 && n == 8) {
2427 if (num != 1 && num != 2 && num != 4) {
2428 return emitOpError(
"expected num attribute to be 1, 2 or 4 for 8x8 "
2431 if (getEltType() != LdStMatrixEltType::B16) {
2432 return emitOpError(
"expected element type to be b16 for 8x8 matrix");
2434 }
else if (m == 8 && n == 16) {
2435 if (num != 1 && num != 2 && num != 4) {
2436 return emitOpError(
"expected num attribute to be 1, 2 or 4 for 8x16 "
2439 if (getLayout() != MMALayout::row) {
2440 return emitOpError(
"expected layout to be row for 8x16 matrix");
2442 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2443 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2444 return emitOpError(
"expected element type to be b8x16.b4x16_p64 or "
2445 "b8x16.b6x16_p32 for 8x16 matrix");
2447 }
else if (m == 16 && n == 16) {
2448 if (num != 1 && num != 2) {
2449 return emitOpError(
"expected num attribute to be 1 or 2 for 16x16 "
2452 if (getLayout() != MMALayout::col) {
2453 return emitOpError(
"expected layout to be col for 16x16 matrix");
2455 if (getEltType() != LdStMatrixEltType::B8 &&
2456 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2457 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2458 return emitOpError(
"expected element type to be b8, b8x16.b4x16_p64 or "
2459 "b8x16.b6x16_p32 for 16x16 matrix");
2462 return emitOpError(
"expected shape to be 8x8, 8x16 or 16x16");
2466 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
2467 if (numElements == 1 &&
getType() != i32)
2468 return emitOpError(
"expected destination type is i32");
2469 if (numElements == 2 || numElements == 4) {
2470 Type dstType = LLVM::LLVMStructType::getLiteral(
2473 return emitOpError(
"expected destination type is a structure of ")
2474 << numElements <<
" elements of type i32";
2480LogicalResult NVVM::StMatrixOp::verify() {
2481 int numMatrix = getSources().size();
2482 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
2483 return emitOpError(
"expected num attribute to be 1, 2 or 4");
2486 if (m == 8 && n == 8) {
2487 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
2488 return emitOpError(
"expected element type to be B16 for 8x8 matrix");
2490 }
else if (m == 16 && n == 8) {
2491 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
2492 return emitOpError(
"expected element type to be B8 for 16x8 matrix");
2494 if (getLayout() != NVVM::MMALayout::col) {
2495 return emitOpError(
"expected layout to be col for 16x8 matrix");
2498 return emitOpError(
"expected shape to be 8x8 or 16x8");
2505 if (typeA == NVVM::WGMMATypes::tf32)
2507 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
2509 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
2511 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
2513 if (typeA == NVVM::WGMMATypes::b1)
2519 NVVM::WGMMATypes typeA,
2520 NVVM::WGMMATypes typeB) {
2522 case NVVM::WGMMATypes::f16:
2523 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2524 typeB == NVVM::WGMMATypes::f16)
2527 case NVVM::WGMMATypes::tf32:
2528 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
2531 case NVVM::WGMMATypes::u8:
2532 case NVVM::WGMMATypes::s8:
2533 if (typeD == NVVM::WGMMATypes::s32 &&
2534 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
2537 case NVVM::WGMMATypes::b1:
2538 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
2541 case NVVM::WGMMATypes::bf16:
2542 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2543 typeB == NVVM::WGMMATypes::bf16)
2546 case NVVM::WGMMATypes::e4m3:
2547 case NVVM::WGMMATypes::e5m2:
2548 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2549 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
2552 case WGMMATypes::f32:
2553 case WGMMATypes::s32:
2554 llvm_unreachable(
"unsupported input types");
2562 72, 80, 88, 96, 104, 112, 120, 128,
2563 136, 144, 152, 160, 168, 176, 184, 192,
2564 200, 208, 216, 224, 232, 240, 248, 256};
2566 80, 96, 112, 128, 144, 160,
2567 176, 192, 208, 224, 240, 256};
2569 case WGMMATypes::f16:
2570 case WGMMATypes::tf32:
2571 case WGMMATypes::bf16:
2572 case WGMMATypes::e4m3:
2573 case WGMMATypes::e5m2:
2574 if (llvm::is_contained(allowedN, sizeN))
2577 case WGMMATypes::u8:
2578 case WGMMATypes::s8:
2579 case WGMMATypes::b1:
2580 if (llvm::is_contained(allowedNshort, sizeN))
2583 case WGMMATypes::f32:
2584 case WGMMATypes::s32:
2585 llvm_unreachable(
"unsupported input types");
2591LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
2592 Value outValue = getResults();
2593 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.
getType());
2595 return emitOpError() <<
"expected results to be struct";
2596 int outputSize = stype.getBody().size();
2597 WGMMATypes typeD = getTypeD();
2598 WGMMATypes typeA = getTypeA();
2599 WGMMATypes typeB = getTypeB();
2601 for (
Type t : stype.getBody()) {
2602 if (t != stype.getBody().front())
2604 <<
"all elements in struct must be same type but there is " << t;
2607 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
2608 typeD != WGMMATypes::s32) {
2609 return emitOpError() <<
"does not support the given output type " << typeD;
2611 if (typeD == WGMMATypes::s32 &&
2612 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
2613 return emitOpError() <<
"has s32 output, scaleA and scaleB cannot be neg";
2617 return emitOpError() << typeD <<
" += " << typeA <<
" * " << typeB
2618 <<
", it is not supported.";
2628 return emitOpError() <<
"shape 'k' must be " << allowedK.value()
2629 <<
" for input type " << typeA;
2633 return emitOpError() <<
"has input type " << typeA <<
" n is set to "
2634 <<
getShape().getN() <<
", it is not supported.";
2641 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
2642 (getLayoutA() == mlir::NVVM::MMALayout::col ||
2643 getLayoutB() == mlir::NVVM::MMALayout::row)) {
2645 <<
"given layouts layout_a = " << getLayoutA()
2646 <<
" and layout_b = " << getLayoutB() <<
" for input types " << typeA
2648 <<
" requires transpose. However, this is only supported for: "
2649 << MMATypes::f16 <<
" and " << MMATypes::bf16;
2653 int expectedOutput = 0;
2654 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
2655 expectedOutput =
getShape().getN() / 2;
2656 if (typeD == WGMMATypes::f16)
2657 expectedOutput =
getShape().getN() / 4;
2658 if (outputSize != expectedOutput) {
2659 return emitOpError() <<
"results " << expectedOutput
2660 <<
", however output struct has " << outputSize
2664 if (typeD != WGMMATypes::s32 &&
2665 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2666 NVVM::MMAIntOverflow::satfinite) {
2668 <<
" `satfinite` can be only used with s32 accumulator, however "
2669 "the current accumulator is "
2676std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
2679 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2681 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
2683 int expectedOutputRegisters = 0;
2684 if (getTypeD() == WGMMATypes::f16)
2685 expectedOutputRegisters =
getShape().getN() / 4;
2687 expectedOutputRegisters =
getShape().getN() / 2;
2690 llvm::raw_string_ostream ss(ptx);
2695 << ((expectedOutputRegisters * 2) + 2)
2697 "wgmma.mma_async.sync.aligned.m"
2698 << m <<
"n" << n <<
"k" << k <<
"." << outputTypeName <<
"." << getTypeA()
2699 <<
"." << getTypeB();
2700 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2701 NVVM::MMAIntOverflow::satfinite)
2705 for (; regCnt < expectedOutputRegisters; ++regCnt) {
2706 ss <<
"$" << regCnt;
2707 if (regCnt != expectedOutputRegisters - 1)
2713 regCnt = (regCnt * 2);
2714 ss <<
" $" << (regCnt) <<
","
2715 <<
" $" << (regCnt + 1) <<
","
2717 if (getTypeD() != WGMMATypes::s32) {
2718 ss <<
", $" << (regCnt + 3) <<
", $" << (regCnt + 4);
2722 ss <<
", $" << (regCnt + 5) <<
", $" << (regCnt + 6);
2729bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2733 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2740 asmValues.push_back({makeConstantI32(rewriter,
static_cast<int>(getScaleD())),
2742 if (getTypeD() != WGMMATypes::s32) {
2743 asmValues.push_back(
2744 {makeConstantI32(rewriter,
2745 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2747 asmValues.push_back(
2748 {makeConstantI32(rewriter,
2749 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2753 asmValues.push_back(
2754 {makeConstantI32(rewriter,
static_cast<int>(getLayoutA())),
2756 asmValues.push_back(
2757 {makeConstantI32(rewriter, 1 -
static_cast<int>(getLayoutB())),
2763LogicalResult NVVM::FenceProxyOp::verify() {
2764 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2765 return emitOpError() <<
"async_shared fence requires space attribute";
2767 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2768 return emitOpError() <<
"only async_shared fence can have space attribute";
2773LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2774 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2775 return emitOpError(
"uni-directional proxies only support generic for "
2776 "from_proxy attribute");
2778 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2779 return emitOpError(
"uni-directional proxies only support tensormap "
2780 "for to_proxy attribute");
2784LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2785 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2786 return emitOpError(
"uni-directional proxies only support generic for "
2787 "from_proxy attribute");
2789 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2790 return emitOpError(
"uni-directional proxies only support tensormap "
2791 "for to_proxy attribute");
2795LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2796 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2797 return emitOpError(
"only generic is support for from_proxy attribute");
2799 if (getToProxy() != NVVM::ProxyKind::async)
2800 return emitOpError(
"only async is supported for to_proxy attribute");
2804LogicalResult NVVM::SetMaxRegisterOp::verify() {
2805 if (getRegCount() % 8)
2806 return emitOpError(
"new register size must be multiple of 8");
2807 if (getRegCount() < 24 || getRegCount() > 256)
2808 return emitOpError(
"new register size must be in between 24 to 256");
2812LogicalResult NVVM::BarrierOp::verify() {
2813 if (getNumberOfThreads() && !getBarrierId())
2815 "barrier id is missing, it should be set between 0 to 15");
2817 if (getBarrierId() && (
getReductionOp() || getReductionPredicate()))
2818 return emitOpError(
"reduction are only available when id is 0");
2822 return emitOpError(
"reduction predicate and reduction operation must be "
2823 "specified together");
2828LogicalResult NVVM::Tcgen05CpOp::verify() {
2829 auto mc = getMulticast();
2831 using SH = Tcgen05CpShape;
2832 using MC = Tcgen05CpMulticast;
2834 case SH::SHAPE_128x256b:
2835 case SH::SHAPE_128x128b:
2836 case SH::SHAPE_4x256b:
2838 return emitError(
"Invalid multicast type for tcgen05.cp Op");
2840 case SH::SHAPE_64x128b:
2841 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
2842 return emitError(
"Shape 64x128b requires multicast warpx2_01_23 or "
2843 "warpx2_02_13 for tcgen05.cp Op");
2845 case SH::SHAPE_32x128b:
2846 if (mc != MC::WARPX4)
2848 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
2854LogicalResult NVVM::MatchSyncOp::verify() {
2855 if (getKind() == NVVM::MatchSyncKind::all) {
2856 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(
getType());
2857 if (!type || type.getBody().size() != 2 ||
2858 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2859 return emitOpError(
"match.sync 'all' returns a two element struct with "
2860 "first element as i32 and second element as i1");
2863 if (!
getType().isInteger(32)) {
2864 return emitOpError(
"match.sync 'any' returns an i32");
2870LogicalResult NVVM::VoteSyncOp::verify() {
2871 if (getKind() == NVVM::VoteSyncKind::ballot) {
2872 if (!
getType().isInteger(32)) {
2873 return emitOpError(
"vote.sync 'ballot' returns an i32");
2876 if (!
getType().isInteger(1)) {
2877 return emitOpError(
"vote.sync 'any', 'all' and 'uni' returns an i1");
2883LogicalResult NVVM::PrefetchOp::verify() {
2884 using MemSpace = NVVM::NVVMMemorySpace;
2885 using CacheLevel = NVVM::PrefetchCacheLevel;
2887 unsigned addressSpace =
2888 llvm::cast<LLVM::LLVMPointerType>(getAddr().
getType()).getAddressSpace();
2889 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
2890 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
2892 if (getTensormap() && cacheLevel)
2893 return emitOpError(
"cannot specify both tensormap and cache level");
2895 if (getTensormap()) {
2896 if (addressSpace != MemSpace::Generic &&
2897 addressSpace != MemSpace::Constant) {
2899 "prefetch tensormap requires a generic or constant pointer");
2902 if (evictPriority) {
2904 "prefetch tensormap does not support eviction priority");
2907 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
2909 "in_param_space can only be specified for a generic pointer");
2912 }
else if (cacheLevel) {
2913 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
2914 addressSpace != MemSpace::Local) {
2915 return emitOpError(
"prefetch to cache level requires a generic, global, "
2916 "or local pointer");
2920 if (*cacheLevel != CacheLevel::L1) {
2922 "unsupported cache level, the only supported uniform "
2923 "cache level is L1");
2926 if (addressSpace != MemSpace::Generic) {
2928 "prefetch to uniform cache requires a generic pointer");
2932 if (evictPriority) {
2933 if (*cacheLevel != CacheLevel::L2)
2935 "cache eviction priority supported only for cache level L2");
2937 if (addressSpace != MemSpace::Global)
2938 return emitOpError(
"cache eviction priority requires a global pointer");
2940 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
2941 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
2943 "unsupported cache eviction priority, only evict_last and "
2944 "evict_normal are supported");
2948 return emitOpError(
"predicate supported only on prefetch tensormap");
2952 "requires specification of either cache level or tensormap");
2958LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
2959 switch (getQueryType()) {
2960 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
2962 return emitOpError(
"is_canceled query type returns an i1");
2964 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
2965 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
2966 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
2967 if (!
getType().isInteger(32)) {
2968 return emitOpError(
"get_first_cta_id_x, get_first_cta_id_y, "
2969 "get_first_cta_id_z query types return an i32");
// Verifies type/attribute constraints for redux.sync.
// NOTE(review): this listing is corrupted -- the declaration of `reduxType`,
// the abs/nan attribute guards, the `switch (kind)` header, the integer-type
// condition, the `return emitOpError("'")` lines, `break`s and the final
// `return success()` were lost during extraction. The comments below describe
// only what the surviving lines show; reconstruct from upstream before use.
2976LogicalResult NVVM::ReduxOp::verify() {
// The abs attribute (and, below, nan) is only meaningful for f32 operands.
2979 if (!reduxType.
isF32()) {
2981 return emitOpError(
"abs attribute is supported only for f32 type");
2983 return emitOpError(
"nan attribute is supported only for f32 type");
2986 NVVM::ReductionKind kind = getKind();
// Integer reduction kinds: the diagnostic states only 'i32' is supported.
2988 case NVVM::ReductionKind::ADD:
2989 case NVVM::ReductionKind::AND:
2990 case NVVM::ReductionKind::OR:
2991 case NVVM::ReductionKind::XOR:
2992 case NVVM::ReductionKind::MAX:
2993 case NVVM::ReductionKind::MIN:
2994 case NVVM::ReductionKind::UMAX:
2995 case NVVM::ReductionKind::UMIN:
2998 << kind <<
"' reduction kind unsupported with " << reduxType
2999 <<
" type. Only supported type is 'i32'.";
// Floating-point reduction kinds: restricted to f32 per the diagnostic.
3001 case NVVM::ReductionKind::FMIN:
3002 case NVVM::ReductionKind::FMAX:
3003 if (!reduxType.isF32())
3005 << kind <<
"' reduction kind unsupported with " << reduxType
3006 <<
" type. Only supported type is 'f32'.";
// Verifies the ordinal / new_value / new_value_attr requirements for each
// tensormap field of nvvm.tensormap.replace.
// NOTE(review): corrupted listing -- the lambda terminators (`.str(); };`),
// most per-case bodies (`return emitOpError(...)`, `break`s) and the final
// `return success()` were dropped during extraction. Comments describe only
// what the surviving lines show.
3013LogicalResult NVVM::TensormapReplaceOp::verify() {
3014 auto ord = getOrd();
3015 Value newVal = getNewValue();
3016 auto newValAttr = getNewValueAttr();
3017 auto fieldName = stringifyEnum(getField());
// An ordinal is only legal for the per-dimension fields listed here.
3019 if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
3020 NVVM::TensormapField::GLOBAL_DIM,
3021 NVVM::TensormapField::GLOBAL_STRIDE,
3022 NVVM::TensormapField::ELEMENT_STRIDE},
3024 return emitOpError(
"ordinal is not supported for ")
3025 << fieldName <<
" field";
// Helper producing the "new_value must be ..." diagnostic text.
3027 auto invalidNewVal = [&](llvm::Twine type) -> std::string {
3028 return llvm::Twine(
"new_value must be specified and must be an " + type +
3029 " for " + llvm::Twine(fieldName) +
" field")
// Helper producing the "new_value_attr must be ..." diagnostic text.
3033 auto invalidNewValAttr = [&]() -> std::string {
3034 return (llvm::Twine(
3035 "new_value_attr must be specified and must be a valid ") +
3036 llvm::Twine(fieldName) +
" attribute for " + fieldName +
" field")
// Per-field validation; each case checks the operand/attribute it requires.
3040 switch (getField()) {
3041 case NVVM::TensormapField::GLOBAL_ADDRESS:
3045 case NVVM::TensormapField::RANK:
// global_stride is per-dimension and therefore needs an ordinal.
3049 case NVVM::TensormapField::GLOBAL_STRIDE:
3051 return emitOpError(
"ordinal is required for global_stride field");
3055 case NVVM::TensormapField::BOX_DIM:
3056 case NVVM::TensormapField::GLOBAL_DIM:
3057 case NVVM::TensormapField::ELEMENT_STRIDE:
3060 << stringifyEnum(getField()) <<
" field";
// The enum-valued fields require a matching typed attribute.
3064 case NVVM::TensormapField::ELEMTYPE:
3065 if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
3068 case NVVM::TensormapField::INTERLEAVE_LAYOUT:
3069 if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
3072 case NVVM::TensormapField::SWIZZLE_MODE:
3073 if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
3076 case NVVM::TensormapField::SWIZZLE_ATOMICITY:
3077 if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
3080 case NVVM::TensormapField::FILL_MODE:
3081 if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
3089template <
typename OpType>
3091 mlir::NVVM::FPRoundingMode rndMode = op.getRnd();
3092 mlir::NVVM::SaturationMode satMode = op.getSat();
3093 bool isFTZ = op.getFtz();
3096 mlir::Type opBaseType = isa<VectorType>(opType)
3097 ? cast<VectorType>(opType).getElementType()
3100 if (opBaseType.
isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3101 return op.emitOpError(
"FTZ and saturation are not supported for "
3102 "additions/subtractions involving f64 type");
3104 if (opBaseType.
isF16() && !(rndMode == NVVM::FPRoundingMode::RN ||
3105 rndMode == NVVM::FPRoundingMode::NONE))
3106 return op.emitOpError(
"only RN rounding mode is supported for f16 and "
3107 "vector<2xf16> additions/subtractions");
3109 if (opBaseType.
isBF16()) {
3110 if (rndMode != NVVM::FPRoundingMode::RN &&
3111 rndMode != NVVM::FPRoundingMode::NONE)
3112 return op.emitOpError(
"only RN rounding mode is supported for bf16 and "
3113 "vector<2xbf16> additions/subtractions");
3114 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3115 return op.emitOpError(
"FTZ and saturation are not supported for bf16 and "
3116 "vector<2xbf16> additions/subtractions");
3123 if (opBaseType.
isF16() && isFTZ && satMode == NVVM::SaturationMode::NONE)
3124 return op.emitOpError(
"FTZ with no saturation is not supported for f16 and "
3125 "vector<2xf16> additions/subtractions");
3134LogicalResult NVVM::FmaOp::verify() {
3135 auto opType = getRes().getType();
3136 mlir::NVVM::FPRoundingMode rndMode = getRnd();
3137 mlir::NVVM::SaturationMode satMode = getSat();
3138 bool isFTZ = getFtz();
3139 bool isRelu = getRelu();
3140 bool hasOOB = getOob();
3142 auto getBaseFType = [](
Type type) ->
Type {
3143 if (isa<VectorType>(type))
3144 return cast<VectorType>(type).getElementType();
3148 auto opBaseType = getBaseFType(opType);
3150 if (rndMode == NVVM::FPRoundingMode::NONE)
3151 return emitOpError(
"rounding mode must be specified");
3153 if (isRelu && satMode == NVVM::SaturationMode::SAT)
3154 return emitOpError(
"relu and saturation are not supported together");
3156 if (hasOOB && (satMode == NVVM::SaturationMode::SAT || isFTZ))
3157 return emitOpError(
"oob is not supported with saturation or FTZ");
3159 if (!(opBaseType.isF16() || opBaseType.isBF16()) && (isRelu || hasOOB))
3160 return emitOpError(
"relu and oob are only supported for f16 and bf16");
3162 if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3163 return emitOpError(
"FTZ and saturation are not supported for f64 type");
3165 if (opBaseType.isF16() && rndMode != NVVM::FPRoundingMode::RN)
3167 "only RN rounding mode is supported for f16 and vector<2xf16>");
3169 if (opBaseType.isBF16()) {
3170 if (rndMode != NVVM::FPRoundingMode::RN)
3172 "only RN rounding mode is supported for bf16 and vector<2xbf16>");
3173 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3175 "FTZ and saturation are not supported for bf16 and vector<2xbf16>");
3187 unsigned sizeInBits,
3189 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3191 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3192 if (mask != 0xffffffffu)
3193 field = builder.CreateAnd(field, builder.getInt32(mask));
3195 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3196 field = builder.CreateShl(field, start);
3198 return builder.CreateOr(
result, field);
// Builds the 64-bit shared-memory matrix descriptor for tcgen05.mma by
// packing op operands into bit-fields of an i64 that starts at zero.
// Visible packing (width, start-bit): leading-dim offset (14, 16),
// stride-dim offset (14, 32), leading-dim mode (1, 52).
// NOTE(review): corrupted listing -- the name of the bit-insert helper
// invoked at each packing site (and any additional fields, e.g. a base
// address or swizzle bits) was dropped during extraction; confirm against
// upstream before relying on this.
3201void Tcgen05MmaSmemDescOp::createSmemDescriptor(
Operation &op,
3203 llvm::IRBuilderBase &builder) {
3204 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
// Descriptor accumulator; fields are OR-ed in below.
3205 llvm::Value *smemDesc = builder.getInt64(0);
// Leading-dimension byte offset -> bits [16, 30).
3210 builder, smemDesc, mt.
lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
// Stride-dimension byte offset -> bits [32, 46).
3212 builder, smemDesc, mt.
lookupValue(thisOp.getStrideDimOffset()), 14, 32);
// Leading-dimension mode flag -> bit 52.
3218 builder, smemDesc, mt.
lookupValue(thisOp.getLeadingDimMode()), 1, 52);
// Publish the completed descriptor as the op's result mapping.
3222 mt.
mapValue(thisOp.getRes()) = smemDesc;
// Returns the inline-PTX string for mbarrier.init, selecting the
// ".shared" variant when the barrier pointer is in shared memory.
// NOTE(review): corrupted listing -- the line defining `isShared`
// (presumably a shared-address-space test on getAddr()) was dropped
// during extraction; confirm against upstream.
3229std::string NVVM::MBarrierInitOp::getPtx() {
3231 return isShared ? std::string(
"mbarrier.init.shared.b64 [%0], %1;")
3232 : std::string(
"mbarrier.init.b64 [%0], %1;");
// Returns the inline-PTX string for mbarrier.arrive.expect_tx, selecting
// the ".shared" variant for shared-memory barrier pointers.
// NOTE(review): corrupted listing -- the `return isShared` line heading
// this conditional expression was dropped during extraction; confirm
// against upstream.
3235std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3238 ? std::string(
"mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3239 : std::string(
"mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
3242std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
3244 llvm::StringRef space = isShared ?
".shared" :
"";
3246 return llvm::formatv(
"{\n\t"
3247 ".reg .pred P1; \n\t"
3249 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
3250 "@P1 bra.uni DONE; \n\t"
3251 "bra.uni LAB_WAIT; \n\t"
3268 LLVM::FNegOp::create(rewriter, loc, op.getRhs().getType(), op.getRhs());
3271 op.getRnd(), op.getSat(), op.getFtz());
3287 auto thisOp = cast<NVVM::BarrierOp>(op);
3288 llvm::Value *barrierId = thisOp.getBarrierId()
3290 : builder.getInt32(0);
3291 llvm::Intrinsic::ID id;
3293 if (thisOp.getNumberOfThreads()) {
3294 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
3295 args.push_back(mt.
lookupValue(thisOp.getNumberOfThreads()));
3296 }
else if (thisOp.getReductionOp()) {
3297 switch (*thisOp.getReductionOp()) {
3298 case NVVM::BarrierReduction::AND:
3299 id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
3301 case NVVM::BarrierReduction::OR:
3302 id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
3304 case NVVM::BarrierReduction::POPC:
3305 id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
3308 args.push_back(builder.CreateICmpNE(
3309 mt.
lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
3311 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
3314 return {id, std::move(args)};
3319 llvm::IRBuilderBase &builder) {
3320 auto thisOp = cast<NVVM::PMEventOp>(op);
3324 llvm::Value *maskVal;
3325 if (
auto eventAttr = thisOp.getEventIdAttr()) {
3326 uint16_t mask =
static_cast<uint16_t
>(1u << eventAttr.getInt());
3327 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3330 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3333 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3338 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3340 llvm::Intrinsic::ID
id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3341 : llvm::Intrinsic::nvvm_mbarrier_init;
3346 args.push_back(mt.
lookupValue(thisOp.getCount()));
3348 return {id, std::move(args)};
3353 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3355 llvm::Intrinsic::ID
id = isShared
3356 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3357 : llvm::Intrinsic::nvvm_mbarrier_inval;
3364 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3367 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3370 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3372 static constexpr llvm::Intrinsic::ID IDs[] = {
3373 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3374 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3375 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3376 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3381 args.push_back(mt.
lookupValue(thisOp.getTxcount()));
3383 return {IDs[
index], std::move(args)};
3388 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3391 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3394 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3396 static constexpr llvm::Intrinsic::ID IDs[] = {
3397 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3398 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3399 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3400 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3405 args.push_back(mt.
lookupValue(thisOp.getTxcount()));
3407 return {IDs[
index], std::move(args)};
3412 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
3415 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3418 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3420 static constexpr llvm::Intrinsic::ID IDs[] = {
3421 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
3422 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
3423 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
3424 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
3425 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3426 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
3427 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
3428 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
3430 nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
3431 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3435 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3442 bool hasCount =
static_cast<bool>(thisOp.getCount());
3444 (
id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
3445 return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};
3449 llvm::Value *count =
3451 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3452 return {id, {mbar, count}};
3457 auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
3460 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3463 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3465 static constexpr llvm::Intrinsic::ID IDs[] = {
3466 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
3467 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
3468 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
3469 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
3470 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3471 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
3473 nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
3475 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
3477 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
3478 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3482 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3488 bool hasCount =
static_cast<bool>(thisOp.getCount());
3489 llvm::Value *count =
3491 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3493 return {id, {mbar, count}};
3496bool MBarrierArriveExpectTxOp::getAsmValues(
3503 for (
auto val : getOperands())
3511 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3514 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3517 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3520 static constexpr llvm::Intrinsic::ID IDs[] = {
3521 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3522 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3523 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3524 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
3525 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3526 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3527 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3528 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3529 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
3531 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3534 llvm::Value *txcount = mt.
lookupValue(thisOp.getTxcount());
3535 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3540 return {id, {mbar, txcount}};
3545 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3548 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3551 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3554 static constexpr llvm::Intrinsic::ID IDs[] = {
3555 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3556 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3557 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3558 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3559 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3560 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3561 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3562 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3563 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3565 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3568 llvm::Value *txcount = mt.
lookupValue(thisOp.getTxcount());
3569 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3574 return {id, {mbar, txcount}};
3579 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3581 llvm::Intrinsic::ID
id =
3582 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3583 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3587 args.push_back(mt.
lookupValue(thisOp.getCount()));
3589 return {id, std::move(args)};
3594 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3596 llvm::Intrinsic::ID
id =
3597 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3598 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3602 args.push_back(mt.
lookupValue(thisOp.getCount()));
3604 return {id, std::move(args)};
3609 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3610 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3611 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3614 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3617 static constexpr llvm::Intrinsic::ID IDs[] = {
3618 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3619 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3620 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3621 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3622 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3623 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3624 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3625 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3626 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3628 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3631 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3632 llvm::Value *input = mt.
lookupValue(thisOp.getStateOrPhase());
3637 return {id, {mbar, input}};
3642 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3643 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3644 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3645 bool hasTicks =
static_cast<bool>(thisOp.getTicks());
3649 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3650 (isPhaseParity ? 1 : 0);
3653 static constexpr llvm::Intrinsic::ID IDs[] = {
3654 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3655 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3656 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3657 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3658 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3659 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3660 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3661 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
3662 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3663 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
3664 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
3665 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
3666 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
3667 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
3668 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
3669 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
3670 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
3672 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3675 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3682 args.push_back(mbar);
3683 args.push_back(mt.
lookupValue(thisOp.getStateOrPhase()));
3685 args.push_back(mt.
lookupValue(thisOp.getTicks()));
3687 return {id, std::move(args)};
3692 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
3695 llvm::Intrinsic::ID id;
3696 if (thisOp.getNoinc()) {
3697 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
3698 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
3700 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
3701 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
3707#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
3708 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
3710#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
3711 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
3716 llvm::Intrinsic::ID id;
3718 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
3719 bool hasCpSize =
static_cast<bool>(cpAsyncOp.getCpSize());
3720 switch (cpAsyncOp.getSize()) {
3728 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
3733 llvm_unreachable(
"Invalid copy size in CpAsyncOp.");
3737 args.push_back(mt.
lookupValue(cpAsyncOp.getDst()));
3738 args.push_back(mt.
lookupValue(cpAsyncOp.getSrc()));
3740 args.push_back(mt.
lookupValue(cpAsyncOp.getCpSize()));
3747 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
3749 llvm::Intrinsic::ID
id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
3752 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3756 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3757 llvm::Value *i64Unused =
3758 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3759 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3760 args.push_back(builder.getInt1(hasCacheHint));
3762 return {id, std::move(args)};
3767 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
3771 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
3773 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3777 mlir::Value multicastMask = thisOp.getMulticastMask();
3778 const bool hasMulticastMask =
static_cast<bool>(multicastMask);
3781 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
3782 args.push_back(hasMulticastMask ? mt.
lookupValue(multicastMask)
3788 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3789 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
3790 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3794 args.push_back(builder.getInt1(hasMulticastMask));
3795 args.push_back(builder.getInt1(hasCacheHint));
3797 llvm::Intrinsic::ID
id =
3799 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
3800 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
3802 return {id, std::move(args)};
3807 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
3809 llvm::Intrinsic::ID
id =
3810 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
3813 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
3814 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3818 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3819 llvm::Value *i64Unused =
3820 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3821 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3822 args.push_back(builder.getInt1(hasCacheHint));
3825 if (
mlir::Value byteMask = thisOp.getByteMask()) {
3827 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
3830 return {id, std::move(args)};
3833bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
3840 for (
auto val : getOperands())
3847CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3849 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
3850 const bool isCTAOnly = thisOp.getIsCTAOnly();
3854 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
3856 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
3866 const bool hasMC =
static_cast<bool>(mcMask);
3867 llvm::Value *i16Zero =
3868 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.
getLLVMContext()), 0);
3872 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3873 llvm::Value *i64Zero =
3874 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3880 thisOp.getGroup() ? (
static_cast<int32_t
>(*thisOp.getGroup()) + 1) : 0;
3882 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.
getLLVMContext()), val);
3886 args.push_back(hasMC ? mt.
lookupValue(mcMask) : i16Zero);
3887 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Zero);
3888 args.push_back(builder.getInt1(hasMC));
3889 args.push_back(builder.getInt1(hasCacheHint));
3893 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Zero);
3894 args.push_back(builder.getInt1(hasCacheHint));
3897 constexpr size_t numDims = 5;
3898 constexpr size_t numModes = 5;
3899 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
3900 using TableTy = std::array<rowTy, numModes>;
3901 static constexpr TableTy IDTable{
3902 {{
notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
3903 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
3904 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
3905 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
3906 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
3908 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
3909 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
3910 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
3912 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
3913 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
3914 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
3916 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
3917 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
3918 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
3920 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
3922 static constexpr TableTy IDTableCTA{
3924 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
3925 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
3926 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
3927 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
3928 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
3930 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
3931 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
3932 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
3934 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
3935 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
3936 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
3938 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
3939 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
3940 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
3942 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
3945 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
3946 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
3947 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
3948 size_t mode =
static_cast<size_t>(thisOp.getMode());
3949 size_t dim = thisOp.getCoordinates().size();
3950 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
3952 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
3954 return {id, std::move(args)};
3959 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
3963 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
3965 for (
auto v : thisOp.getCoordinates())
3967 for (
auto v : thisOp.getIm2colOffsets())
3971 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3972 llvm::Value *i64Unused =
3973 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3974 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3975 args.push_back(builder.getInt1(hasCacheHint));
3977 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3978 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3979 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
3980 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
3981 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
3982 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
3983 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
3985 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
3986 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
3987 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
3989 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
3990 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
3991 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
3993 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
3994 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
3995 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
3996 {NI, NI, NI, NI, NI,
3997 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
3999 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
4000 "TMALoadModes must match number of rows in IDTable");
4001 size_t mode =
static_cast<size_t>(thisOp.getMode());
4002 size_t dim = thisOp.getCoordinates().size();
4003 llvm::Intrinsic::ID
id = IDTable[mode][dim];
4004 if (
id == llvm::Intrinsic::not_intrinsic)
4005 llvm_unreachable(
"Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
4007 return {id, std::move(args)};
4011CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
4013 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
4017 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4018 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
4020 for (
auto v : thisOp.getCoordinates())
4024 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4025 llvm::Value *i64Unused =
4026 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4027 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4028 args.push_back(builder.getInt1(hasCacheHint));
4030 const unsigned NI = llvm::Intrinsic::not_intrinsic;
4031 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4032 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
4033 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
4034 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
4035 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
4036 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
4037 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
4038 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
4039 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
4040 {NI, NI, NI, NI, NI,
4041 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
4043 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
4044 "TMAStoreModes must match number of rows in IDTable");
4045 size_t mode =
static_cast<size_t>(thisOp.getMode());
4046 size_t dim = thisOp.getCoordinates().size();
4047 llvm::Intrinsic::ID
id = IDTable[mode][dim];
4048 if (
id == llvm::Intrinsic::not_intrinsic)
4050 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
4052 return {id, std::move(args)};
4057 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
4065 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4066 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
4068 for (
Value v : thisOp.getCoordinates())
4072 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4073 llvm::Value *i64ZeroValue =
4074 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
4075 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64ZeroValue);
4076 args.push_back(builder.getInt1(hasCacheHint));
4078 const llvm::Intrinsic::ID
notIntrinsic = llvm::Intrinsic::not_intrinsic;
4080 constexpr unsigned numRedKinds = 8;
4081 constexpr unsigned numLayouts = 2;
4082 constexpr unsigned maxDim = 5;
4083 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
4084 using layoutTable = std::array<row, numLayouts>;
4085 using fullTable = std::array<layoutTable, numRedKinds>;
4086 static constexpr fullTable IDTable{
4089 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
4090 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
4091 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
4092 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
4093 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
4095 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
4096 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
4097 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
4100 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
4101 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
4102 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
4103 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
4104 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
4106 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
4107 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
4108 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
4111 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
4112 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
4113 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
4114 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
4115 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
4117 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
4118 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
4119 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
4122 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
4123 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
4124 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
4125 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
4126 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
4128 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
4129 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
4130 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
4133 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
4134 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
4135 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
4136 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
4137 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
4139 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
4140 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
4141 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
4144 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
4145 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
4146 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
4147 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
4148 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
4150 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
4151 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
4152 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
4155 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
4156 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
4157 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
4158 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
4159 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
4161 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
4162 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
4163 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
4166 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
4167 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
4168 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
4169 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
4170 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
4172 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
4173 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
4175 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
4177 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
4178 "TMAReduxKinds must match number of rows in IDTable");
4180 size_t redKind =
static_cast<size_t>(thisOp.getRedKind());
4181 size_t mode =
static_cast<size_t>(thisOp.getMode());
4182 size_t dim = thisOp.getCoordinates().size();
4184 assert(redKind < IDTable.size() &&
4185 "Invalid redKind for CpAsyncBulkTensorReduceOp");
4186 assert(mode < IDTable[redKind].size() &&
4187 "Invalid mode for CpAsyncBulkTensorReduceOp");
4188 assert(dim < IDTable[redKind][mode].size() &&
4189 "Invalid dim for CpAsyncBulkTensorReduceOp");
4191 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
4194 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4196 return {intrinsicID, std::move(args)};
4201#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4202 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
4203 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
4205#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
4206 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4207 : CVT_F2TF32_ID_IMPL(rnd, relu, )
4210ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4211 NVVM::SaturationMode sat,
bool hasRelu) {
4212 using RndMode = NVVM::FPRoundingMode;
4213 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4222 llvm_unreachable(
"Invalid RoundingMode for CvtFloatToTF32Op");
4227ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4229 llvm::IRBuilderBase &builder) {
4234 bool hasRelu = op.getRelu();
4236 llvm::Intrinsic::ID intId =
4237 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4238 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4240 return {intId, std::move(args)};
4243#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
4244 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
4245 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
4247llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4250 .Case([&](mlir::Float6E2M3FNType) {
4253 .Case([&](mlir::Float6E3M2FNType) {
4257 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF6x2Op");
4258 return llvm::Intrinsic::not_intrinsic;
4263ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
4265 llvm::IRBuilderBase &builder) {
4267 bool hasRelu = op.getRelu();
4269 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4271 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4272 intId = hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
4273 : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
4278 return {intId, std::move(args)};
4282ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
4284 llvm::IRBuilderBase &builder) {
4286 bool hasRelu = op.getRelu();
4288 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4290 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4291 intId = hasRelu ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
4292 : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
4297 return {intId, std::move(args)};
4300llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4303 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4304 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite
4305 : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
4307 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4308 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite
4309 : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
4312 llvm_unreachable(
"Invalid conversion in ConvertF16x2ToF6x2Op");
4313 return llvm::Intrinsic::not_intrinsic;
4317llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4320 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4322 ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite
4323 : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
4325 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4327 ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite
4328 : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
4331 llvm_unreachable(
"Invalid conversion in ConvertBF16x2ToF6x2Op");
4332 return llvm::Intrinsic::not_intrinsic;
4336#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
4337 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
4338 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
4340#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
4341 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
4342 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4345ConvertF32x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4346 NVVM::SaturationMode sat,
bool hasRelu) {
4347 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4348 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4349 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4352 .Case([&](mlir::Float8E4M3FNType) {
4355 .Case([&](mlir::Float8E5M2Type) {
4358 .Case([&](mlir::Float8E8M0FNUType) {
4359 if (hasRoundingModeRZ)
4361 else if (hasRoundingModeRP)
4364 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
4367 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
4368 return llvm::Intrinsic::not_intrinsic;
4372#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
4373 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
4374 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
4376llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy,
4379 .Case([&](mlir::Float8E4M3FNType) {
4382 .Case([&](mlir::Float8E5M2Type) {
4386 llvm_unreachable(
"Invalid conversion in ConvertF16x2ToF8x2Op");
4387 return llvm::Intrinsic::not_intrinsic;
4392ConvertBF16x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy,
4393 NVVM::FPRoundingMode rnd,
4394 NVVM::SaturationMode sat,
bool hasRelu) {
4395 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4397 static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
4398 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
4399 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
4400 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
4401 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
4405 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4407 ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
4408 : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
4410 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4412 ? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
4413 : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
4415 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
4416 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4417 unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
4418 return ue8m0x2IDs[
index];
4421 llvm_unreachable(
"Invalid conversion in ConvertBF16x2ToF8x2Op");
4422 return llvm::Intrinsic::not_intrinsic;
4428 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4430 bool hasRelu = curOp.getRelu();
4432 llvm::Intrinsic::ID intId =
4434 .Case([&](Float8E4M3FNType type) {
4435 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4436 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4438 .Case([&](Float8E5M2Type type) {
4439 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4440 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4443 llvm_unreachable(
"Invalid type for ConvertF8x2ToF16x2Op");
4444 return llvm::Intrinsic::not_intrinsic;
4447 llvm::Value *packedI16 =
4448 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4449 llvm::Type::getInt16Ty(builder.getContext()));
4451 return {intId, {packedI16}};
4456 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4458 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4459 llvm::Value *packedI16 =
4460 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4461 llvm::Type::getInt16Ty(builder.getContext()));
4463 return {intId, {packedI16}};
4468 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4470 bool hasRelu = curOp.getRelu();
4472 llvm::Intrinsic::ID intId =
4474 .Case([&](Float6E2M3FNType type) {
4475 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4476 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4478 .Case([&](Float6E3M2FNType type) {
4479 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4480 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4483 llvm_unreachable(
"Invalid type for ConvertF6x2ToF16x2Op");
4484 return llvm::Intrinsic::not_intrinsic;
4487 llvm::Value *packedI16 =
4488 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4489 llvm::Type::getInt16Ty(builder.getContext()));
4491 return {intId, {packedI16}};
4496 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4498 bool hasRelu = curOp.getRelu();
4500 llvm::Intrinsic::ID intId =
4502 .Case([&](Float4E2M1FNType type) {
4503 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4504 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4507 llvm_unreachable(
"Invalid type for ConvertF4x2ToF16x2Op");
4508 return llvm::Intrinsic::not_intrinsic;
4511 llvm::Value *extendedI16 =
4512 builder.CreateZExt(mt.
lookupValue(curOp.getSrc()),
4513 llvm::Type::getInt16Ty(builder.getContext()));
4515 return {intId, {extendedI16}};
4520 auto thisOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);
4521 bool hasRelu = thisOp.getRelu();
4522 bool hasScale =
static_cast<bool>(thisOp.getScaleFactor());
4524 llvm::Intrinsic::ID
id =
4526 ? llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
4527 : llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
4533 args.push_back(hasScale ? mt.
lookupValue(thisOp.getScaleFactor())
4534 : builder.getInt16(0x7f7f));
4535 return {id, std::move(args)};
4540 auto thisOp = cast<NVVM::ConvertBF16x2ToS2F6x2Op>(op);
4541 bool hasRelu = thisOp.getRelu();
4542 bool hasScale =
static_cast<bool>(thisOp.getScaleFactor());
4544 llvm::Intrinsic::ID
id =
4547 nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
4548 : llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
4553 args.push_back(hasScale ? mt.
lookupValue(thisOp.getScaleFactor())
4554 : builder.getInt16(0x7f7f));
4555 return {id, std::move(args)};
4560 auto thisOp = cast<NVVM::ConvertS2F6x2ToBF16x2Op>(op);
4561 bool hasRelu = thisOp.getRelu();
4562 bool hasScale =
static_cast<bool>(thisOp.getScaleFactor());
4563 bool hasSat = thisOp.getSat() == NVVM::SaturationMode::SATFINITE;
4565 static constexpr llvm::Intrinsic::ID ids[] = {
4566 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0,
4567 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4568 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4569 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4572 unsigned idx = (hasSat << 1) | hasRelu;
4576 llvm::Value *packedI16 =
4577 builder.CreateBitCast(mt.
lookupValue(thisOp.getSrc()),
4578 llvm::Type::getInt16Ty(builder.getContext()));
4579 args.push_back(packedI16);
4580 args.push_back(hasScale ? mt.
lookupValue(thisOp.getScaleFactor())
4581 : builder.getInt16(0x7f7f));
4583 return {ids[idx], std::move(args)};
4587Tcgen05AllocOp::getIntrinsicIDAndArgs(
Operation &op,
4590 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
4591 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4593 bool isShared = as == NVVMMemorySpace::Shared;
4594 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4596 llvm::Intrinsic::ID id;
4598 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
4599 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
4601 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
4602 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
4612llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
4615 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
4616 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
4617 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
4618 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
4627#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
4628 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
4629 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
4631#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
4632 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
4633 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
4636Tcgen05CommitOp::getIntrinsicIDAndArgs(
Operation &op,
4639 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
4640 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4642 bool isShared = as == NVVMMemorySpace::Shared;
4643 bool hasMulticast =
static_cast<bool>(curOp.getMulticastMask());
4644 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4646 llvm::Intrinsic::ID
id =
4653 args.push_back(mt.
lookupValue(curOp.getMulticastMask()));
// Compose the tcgen05.cp intrinsic name from shape/multicast suffix, source
// format suffix, and CTA group.
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

// Wrapped in an immediately-invoked lambda so the runtime src_fmt dispatch
// can be used as a single expression in return statements.
// NOTE(review): the lambda header line was lost in extraction and is
// reconstructed to match the surviving `if`/`return` body.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()
4675ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
4677 llvm::IRBuilderBase &builder) {
4678 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4679 llvm::Intrinsic::nvvm_ff2f16x2_rn,
4680 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
4681 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
4682 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
4684 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4685 llvm::Intrinsic::nvvm_ff2f16x2_rz,
4686 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
4687 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
4688 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
4690 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4691 llvm::Intrinsic::nvvm_ff2f16x2_rs,
4692 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
4693 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
4694 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
4697 unsigned hasRelu = op.getRelu() ? 1 : 0;
4698 unsigned hasSatFinite =
4699 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4702 unsigned idx = (hasSatFinite << 1) | hasRelu;
4707 if (op.getRandomBits())
4708 args.push_back(mt.
lookupValue(op.getRandomBits()));
4710 switch (op.getRnd()) {
4711 case FPRoundingMode::RN:
4712 return {rndRNIds[idx], std::move(args)};
4713 case FPRoundingMode::RZ:
4714 return {rndRZIds[idx], std::move(args)};
4715 case FPRoundingMode::RS:
4716 return {rndRSIds[idx], std::move(args)};
4718 llvm_unreachable(
"Invalid rounding mode for ConvertF32x2ToF16x2Op");
4723ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
4725 llvm::IRBuilderBase &builder) {
4726 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4727 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
4728 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
4729 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
4730 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
4732 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4733 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
4734 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
4735 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
4736 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
4738 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4739 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
4740 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
4741 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
4742 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
4745 unsigned hasRelu = op.getRelu() ? 1 : 0;
4746 unsigned hasSatFinite =
4747 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4750 unsigned idx = (hasSatFinite << 1) | hasRelu;
4755 if (op.getRandomBits())
4756 args.push_back(mt.
lookupValue(op.getRandomBits()));
4758 switch (op.getRnd()) {
4759 case FPRoundingMode::RN:
4760 return {rndRNIds[idx], std::move(args)};
4761 case FPRoundingMode::RZ:
4762 return {rndRZIds[idx], std::move(args)};
4763 case FPRoundingMode::RS:
4764 return {rndRSIds[idx], std::move(args)};
4766 llvm_unreachable(
"Invalid rounding mode for ConvertF32x2ToBF16x2Op");
4770llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
4772 bool hasRelu = getRelu();
4775 .Case([&](mlir::Float8E4M3FNType) {
4776 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
4777 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
4779 .Case([&](mlir::Float8E5M2Type) {
4780 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
4781 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
4784 llvm_unreachable(
"Invalid F8 type in ConvertF32x4ToF8x4Op");
4785 return llvm::Intrinsic::not_intrinsic;
4789llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
4791 bool hasRelu = getRelu();
4794 .Case([&](mlir::Float6E2M3FNType) {
4795 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
4796 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
4798 .Case([&](mlir::Float6E3M2FNType) {
4799 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
4800 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
4803 llvm_unreachable(
"Invalid F6 type in ConvertF32x4ToF6x4Op");
4804 return llvm::Intrinsic::not_intrinsic;
4808llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
4810 bool hasRelu = getRelu();
4813 .Case([&](mlir::Float4E2M1FNType) {
4814 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
4815 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
4818 llvm_unreachable(
"Invalid F4 type in ConvertF32x4ToF4x4Op");
4819 return llvm::Intrinsic::not_intrinsic;
4823llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(
Operation &op) {
4824 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
4825 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
4826 auto srcFmt = curOp.getSrcFormat();
4827 auto mc = curOp.getMulticast();
4829 switch (curOp.getShape()) {
4830 case Tcgen05CpShape::SHAPE_128x256b:
4832 case Tcgen05CpShape::SHAPE_128x128b:
4834 case Tcgen05CpShape::SHAPE_4x256b:
4836 case Tcgen05CpShape::SHAPE_32x128b:
4838 case Tcgen05CpShape::SHAPE_64x128b:
4839 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
4843 llvm_unreachable(
"Invalid shape in tcgen05 cp Op");
4850 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
4852 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
4857LogicalResult Tcgen05LdOp::verify() {
4859 if (
getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4862 if (
getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
4863 result =
emitError(
"offset argument is only supported for shape 16x32bx2");
4865 auto resTy = getRes().getType();
4866 unsigned resLen = isa<VectorType>(resTy)
4867 ? llvm::cast<VectorType>(resTy).getNumElements()
4870 result =
emitError(llvm::formatv(
"invalid result type length {0} for shape "
4871 "{1} in tcgen05.ld Op",
4872 resLen, stringifyEnum(
getShape())));
4877LogicalResult Tcgen05StOp::verify() {
4879 if (
getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4882 auto valTy = getVal().getType();
4883 unsigned valLen = isa<VectorType>(valTy)
4884 ? llvm::cast<VectorType>(valTy).getNumElements()
4887 result =
emitError(llvm::formatv(
"invalid input length {0} for shape "
4888 "{1} in tcgen05.st Op",
4889 valLen, stringifyEnum(
getShape())));
4899 if (
auto rangeAttr = op->
getAttrOfType<LLVM::ConstantRangeAttr>(
"range")) {
4900 setResultRanges(
result, {rangeAttr.getLower(), rangeAttr.getUpper(),
4901 rangeAttr.getLower(), rangeAttr.getUpper()});
4911 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
4915 const llvm::APInt &lower = rangeAttr->getLower();
4916 const llvm::APInt &upper = rangeAttr->getUpper();
4919 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
4920 unsigned bitWidth = lower.getBitWidth();
4921 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
4922 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
4924 "invalid range attribute: Lower == Upper, but they aren't min (")
4925 << llvm::toString(minVal, 10,
false) <<
") or max ("
4926 << llvm::toString(maxVal, 10,
false)
4927 <<
") value! This is an invalid constant range.";
4934 llvm::IRBuilderBase &builder) {
4935 return builder.CreateBitCast(arg,
4936 llvm::Type::getInt32Ty(builder.getContext()));
4941 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
4948 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4949 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4950 unsigned type = (isASigned << 1) | isBSigned;
4951 const llvm::Intrinsic::ID ids[] = {
4952 llvm::Intrinsic::nvvm_idp4a_u_u,
4953 llvm::Intrinsic::nvvm_idp4a_u_s,
4954 llvm::Intrinsic::nvvm_idp4a_s_u,
4955 llvm::Intrinsic::nvvm_idp4a_s_s,
4957 return {ids[type], args};
4962 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
4967 args.push_back(builder.getInt1(curOp.getBHi()));
4970 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4971 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4972 unsigned type = (isASigned << 1) | isBSigned;
4973 const llvm::Intrinsic::ID ids[] = {
4974 llvm::Intrinsic::nvvm_idp2a_u_u,
4975 llvm::Intrinsic::nvvm_idp2a_u_s,
4976 llvm::Intrinsic::nvvm_idp2a_s_u,
4977 llvm::Intrinsic::nvvm_idp2a_s_s,
4979 return {ids[type], args};
4983 llvm::IRBuilderBase &builder) {
4984 return builder.CreateAddrSpaceCast(
4985 addr, builder.getPtrTy(llvm::NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM));
4989PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
4991 llvm::IRBuilderBase &builder) {
4992 using MemSpace = NVVM::NVVMMemorySpace;
4993 using CacheLevel = NVVM::PrefetchCacheLevel;
4995 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
4996 std::optional<NVVM::CacheEvictionPriority> evictPriority =
4997 op.getEvictPriority();
4998 unsigned addressSpace =
4999 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
5007 if (op.getTensormap())
5008 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
5010 assert(cacheLevel &&
"expected cache level for non-tensormap prefetch");
5012 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
5013 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
5015 if (evictPriority && *cacheLevel == CacheLevel::L2) {
5016 switch (*evictPriority) {
5017 case NVVM::CacheEvictionPriority::EvictLast:
5018 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
5019 case NVVM::CacheEvictionPriority::EvictNormal:
5020 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
5022 llvm_unreachable(
"Invalid cache eviction priority");
5026 switch (
static_cast<MemSpace
>(addressSpace)) {
5027 case MemSpace::Generic:
5028 return *cacheLevel == CacheLevel::L1
5030 :
NVVM::
IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
5031 case MemSpace::Global:
5032 return *cacheLevel == CacheLevel::L1
5034 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
5036 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
5037 case MemSpace::Local:
5038 return *cacheLevel == CacheLevel::L1
5040 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
5042 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
5044 llvm_unreachable(
"Invalid pointer address space");
5048bool NVVM::InlinePtxOp::getAsmValues(
5052 for (
auto arg : getReadWriteArgs())
5054 for (
auto arg : getResults())
5056 for (
auto arg : getReadOnlyArgs())
5063NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
5065 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
5067 args.push_back(mt.
lookupValue(curOp.getSmemAddress()));
5068 args.push_back(mt.
lookupValue(curOp.getMbarrier()));
5070 llvm::Intrinsic::ID intrinsicID =
5071 curOp.getMulticast()
5073 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
5074 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
5076 return {intrinsicID, args};
5079NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
5081 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
5083 args.push_back(mt.
lookupValue(curOp.getTryCancelResponse()));
5085 llvm::Intrinsic::ID intrinsicID;
5087 switch (curOp.getQueryType()) {
5088 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
5090 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
5092 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
5093 intrinsicID = llvm::Intrinsic::
5094 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
5096 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
5097 intrinsicID = llvm::Intrinsic::
5098 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
5100 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
5101 intrinsicID = llvm::Intrinsic::
5102 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
5105 return {intrinsicID, args};
5110 llvm::IRBuilderBase &builder) {
5111 auto thisOp = cast<NVVM::PermuteOp>(op);
5112 NVVM::PermuteMode mode = thisOp.getMode();
5114 static constexpr llvm::Intrinsic::ID IDs[] = {
5115 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
5116 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
5117 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
5118 llvm::Intrinsic::nvvm_prmt_rc16};
5120 unsigned modeIndex =
static_cast<unsigned>(mode);
5128 args.push_back(mt.
lookupValue(thisOp.getSelector()));
5130 return {IDs[modeIndex], args};
5135 auto thisOp = cast<NVVM::TensormapReplaceOp>(op);
5139 if (thisOp.getOrd())
5140 args.push_back(builder.getInt32(thisOp.getOrd().value()));
5141 if (thisOp.getNewValue())
5142 args.push_back(mt.
lookupValue(thisOp.getNewValue()));
5143 if (
auto attr = thisOp.getNewValueAttr()) {
5146 .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
5147 TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
5148 TensormapFillModeAttr>([](
auto attr) {
5149 return static_cast<unsigned>(attr.getValue());
5151 .Default([](
auto attr) {
5152 llvm_unreachable(
"Invalid attribute type");
5155 args.push_back(builder.getInt32(val));
5158 static constexpr llvm::Intrinsic::ID IDs[] = {
5159 llvm::Intrinsic::nvvm_tensormap_replace_global_address,
5160 llvm::Intrinsic::nvvm_tensormap_replace_rank,
5161 llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
5162 llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
5163 llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
5164 llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
5165 llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
5166 llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
5167 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
5168 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
5169 llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
5172 unsigned fieldIndex =
static_cast<unsigned>(thisOp.getField());
5174 return {IDs[fieldIndex], args};
5183 llvm::IRBuilderBase &builder) {
5185 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
5188 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5191 const bool isATensor = isa<llvm::PointerType>(
A->getType());
5194 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5195 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5196 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5198 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5199 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5200 using IsATensorArray = std::array<CtaGroupArray, 2>;
5201 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5202 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5205 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
5211 {llvm::Intrinsic::nvvm_tcgen05_mma_shared,
notIntrinsic},
5213 {llvm::Intrinsic::nvvm_tcgen05_mma_shared,
notIntrinsic}}},
5217 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5218 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5222 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5223 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5229 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d,
notIntrinsic},
5231 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d,
notIntrinsic}}},
5235 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5236 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5240 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5241 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5247 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
5250 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
5255 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
5257 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
5262 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
5264 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
5270 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
5274 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
5279 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
5281 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
5285 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
5287 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
5290 llvm::Value *ScaleInputD = mt.
lookupValue(thisOp.getScaleInputD());
5291 bool hasScaleInputD = ScaleInputD !=
nullptr;
5293 llvm::Value *DisableOutputLane =
5295 bool hasDisableOutputLane = DisableOutputLane !=
nullptr;
5297 const unsigned ctaGroup =
5300 llvm::Intrinsic::ID ID =
5301 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5302 [ctaGroup - 1][thisOp.getAShift()];
5304 assert(ID !=
notIntrinsic &&
"Invalid intrinsic for Tcgen05MMAOp.");
5307 args.push_back(ScaleInputD);
5309 if (hasDisableOutputLane)
5310 args.push_back(DisableOutputLane);
5312 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5314 if (!hasDisableOutputLane)
5315 args.push_back(builder.getInt32(ctaGroup));
5318 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5325 NVVM::CTAGroupKind ctaGroup,
bool hasAShift,
5326 NVVM::Tcgen05MMACollectorOp collectorOp,
Location loc) {
5328 if (disableOutputLane) {
5329 mlir::VectorType disableOutputLaneType =
5330 cast<mlir::VectorType>(disableOutputLane.
getType());
5331 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
5332 disableOutputLaneType.getNumElements() != 4) ||
5333 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
5334 disableOutputLaneType.getNumElements() != 8))
5335 return emitError(loc) <<
"Disable Output Lane of length "
5336 << disableOutputLaneType.getNumElements()
5337 <<
" is incompatible with CtaGroupAttr";
5340 if (hasAShift && !isATensor)
5342 loc,
"A-shift can be applied only when matrix A is in tensor memory");
5344 if (hasAShift ==
true && (collectorOp == Tcgen05MMACollectorOp::FILL ||
5345 collectorOp == Tcgen05MMACollectorOp::USE))
5347 loc,
"Cannot use collector buffer operation fill or use with ashift");
5352LogicalResult Tcgen05MMAOp::verify() {
5354 getDisableOutputLane(), getCtaGroup(), getAShift(),
5355 getCollectorOp(), getLoc());
5365 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
5368 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5371 bool isATensor = isa<llvm::PointerType>(
A->getType());
5374 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5375 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5376 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5377 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
5379 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5380 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5381 using IsATensorArray = std::array<CtaGroupArray, 2>;
5382 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5383 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5386 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
5392 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared,
notIntrinsic},
5394 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared,
notIntrinsic}}},
5398 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5399 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5403 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5404 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5410 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5413 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5418 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5419 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5423 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5424 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5431 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5435 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5440 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5442 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5447 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5449 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5455 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5459 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5464 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5466 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5470 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5472 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5475 llvm::Value *ScaleInputD = mt.
lookupValue(thisOp.getScaleInputD());
5476 bool hasScaleInputD = ScaleInputD !=
nullptr;
5478 llvm::Value *DisableOutputLane =
5480 bool hasDisableOutputLane = DisableOutputLane !=
nullptr;
5485 llvm::Intrinsic::ID ID =
5486 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5487 [ctaGroup - 1][thisOp.getAShift()];
5489 assert(ID !=
notIntrinsic &&
"Invalid intrinsic for Tcgen05MMASparseOp.");
5492 args.push_back(ScaleInputD);
5494 if (hasDisableOutputLane)
5495 args.push_back(DisableOutputLane);
5497 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5499 if (!hasDisableOutputLane)
5500 args.push_back(builder.getInt32(ctaGroup));
5503 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5508LogicalResult Tcgen05MMASparseOp::verify() {
5510 getDisableOutputLane(), getCtaGroup(), getAShift(),
5511 getCollectorOp(), getLoc());
5521 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5524 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5527 bool isATensor = isa<llvm::PointerType>(
A->getType());
5530 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5531 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5532 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5533 args.push_back(mt.
lookupValue(thisOp.getScaleA()));
5534 args.push_back(mt.
lookupValue(thisOp.getScaleB()));
5535 args.push_back(builder.getInt32(
5538 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5540 auto kind = thisOp.getKind();
5541 auto blockScale = thisOp.getBlockScale();
5542 llvm::Intrinsic::ID ID = [&]() {
5543 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5544 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5545 return isATensor ? llvm::Intrinsic::
5546 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
5548 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
5549 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5552 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
5554 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
5556 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5557 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5559 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
5560 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
5561 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5562 return isATensor ? llvm::Intrinsic::
5563 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
5565 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
5567 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5568 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5571 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
5573 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
5575 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5578 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
5580 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
5583 llvm_unreachable(
"Invalid tcgen05.mma.block_scale attributes");
5590 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind,
5591 NVVM::Tcgen05MMABlockScale blockScale,
Location loc) {
5592 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
5593 kind == NVVM::Tcgen05MMAKind::MXF4NVF4)
5594 return emitError(loc,
"mxf4nvf4 requires block scale attribute");
5596 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
5597 kind != NVVM::Tcgen05MMAKind::MXF4NVF4)
5599 llvm::formatv(
"{} kind does not support block16 attribute",
5600 stringifyEnum(kind)));
5605LogicalResult Tcgen05MMABlockScaleOp::verify() {
5607 getBlockScale(), getLoc());
5617 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
5620 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5623 bool isATensor = isa<llvm::PointerType>(
A->getType());
5626 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5627 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5628 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5629 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
5630 args.push_back(mt.
lookupValue(thisOp.getScaleA()));
5631 args.push_back(mt.
lookupValue(thisOp.getScaleB()));
5632 args.push_back(builder.getInt32(
5635 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5637 auto kind = thisOp.getKind();
5638 auto blockScale = thisOp.getBlockScale();
5639 llvm::Intrinsic::ID ID = [&]() {
5640 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5641 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5642 return isATensor ? llvm::Intrinsic::
5643 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
5645 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
5646 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5649 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
5651 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
5653 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5654 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5655 return isATensor ? llvm::Intrinsic::
5656 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
5658 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
5659 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5662 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
5664 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
5666 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5667 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5670 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
5672 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
5674 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5677 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
5679 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
5682 llvm_unreachable(
"Invalid tcgen05.mma.sp.block_scale attributes");
5688LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
5690 getBlockScale(), getLoc());
5700 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
5703 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5706 bool isATensor = isa<llvm::PointerType>(
A->getType());
5709 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5710 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5711 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5713 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5717 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
5718 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
5720 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
5721 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
5723 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5725 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5727 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5739 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
5742 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5745 bool isATensor = isa<llvm::PointerType>(
A->getType());
5748 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5749 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5750 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5751 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
5753 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5758 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
5759 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
5761 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
5762 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
5764 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5766 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5768 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5777#define TCGEN05LDRED(SHAPE, NUM, TYPE) \
5778 llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE
5782 auto thisOp = cast<NVVM::Tcgen05LdRedOp>(op);
5785 mlir::VectorType VecResTy =
5786 cast<mlir::VectorType>(thisOp.getData().getType());
5787 unsigned Num = VecResTy.getNumElements();
5788 bool IsFloat = thisOp.getRedVal().getType().isF32();
5790 llvm::Intrinsic::ID Shape32x32b[][2] = {
5801 llvm::Intrinsic::ID Shape16x32bx2[][2] = {
5812 NVVM::Tcgen05LdStShape
shape = thisOp.getShape();
5813 unsigned ID = [&]() {
5816 unsigned idx = std::log2(Num);
5818 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
5819 return Shape32x32b[idx][IsFloat];
5820 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
5821 return Shape16x32bx2[idx][IsFloat];
5823 llvm_unreachable(
"unhandled tcgen05.ld lowering");
5829 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2)
5830 args.push_back(mt.
lookupValue(thisOp.getOffset()));
5833 builder.getInt32(thisOp.getOp() == NVVM::ReductionKind::MIN ? 0 : 1));
5836 args.push_back(builder.getInt1(
static_cast<unsigned>(thisOp.getAbs())));
5837 args.push_back(builder.getInt1(
static_cast<unsigned>(thisOp.getNan())));
5842LogicalResult Tcgen05LdRedOp::verify() {
5843 VectorType data = cast<VectorType>(getData().
getType());
5844 Type redVal = getRedVal().getType();
5846 if (data.getElementType() != redVal)
5848 "type of reduction value and element type of vector data should match");
5850 if (getOp() != NVVM::ReductionKind::MIN &&
5851 getOp() != NVVM::ReductionKind::MAX)
5852 return emitError(
"only min and max reduction kinds are supported");
5854 if (redVal.
isInteger() && (getAbs() || getNan())) {
5855 return emitError(
"abs or nan is only applicable for f32 type");
5865void NVVMDialect::initialize() {
5868#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5871#define GET_ATTRDEF_LIST
5872#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
5877 allowUnknownOperations();
5878 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
5879 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
5882LogicalResult NVVMDialect::verifyOperationAttribute(
Operation *op,
5884 StringAttr attrName = attr.
getName();
5886 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
5887 if (!isa<LLVM::LLVMFuncOp>(op)) {
5888 return op->
emitError() <<
"'" << NVVMDialect::getKernelFuncAttrName()
5889 <<
"' attribute attached to unexpected op";
5894 if (attrName == NVVMDialect::getMaxntidAttrName() ||
5895 attrName == NVVMDialect::getReqntidAttrName() ||
5896 attrName == NVVMDialect::getClusterDimAttrName()) {
5897 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
5898 if (!values || values.empty() || values.size() > 3) {
5901 <<
"' attribute must be integer array with maximum 3 index";
5906 if (attrName == NVVMDialect::getMinctasmAttrName() ||
5907 attrName == NVVMDialect::getMaxnregAttrName() ||
5908 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
5909 if (!llvm::dyn_cast<IntegerAttr>(attr.
getValue())) {
5911 <<
"'" << attrName <<
"' attribute must be integer constant";
5915 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
5916 if (!op->
hasAttr(NVVMDialect::getReqntidAttrName()) ||
5917 !op->
hasAttr(NVVMDialect::getClusterDimAttrName())) {
5919 <<
"'" << attrName <<
"' attribute must be used along with "
5920 <<
"'" << NVVMDialect::getReqntidAttrName() <<
"' and "
5921 <<
"'" << NVVMDialect::getClusterDimAttrName() <<
"'";
5928LogicalResult NVVMDialect::verifyRegionArgAttribute(
Operation *op,
5929 unsigned regionIndex,
5932 auto funcOp = dyn_cast<FunctionOpInterface>(op);
5936 bool isKernel = op->
hasAttr(NVVMDialect::getKernelFuncAttrName());
5937 StringAttr attrName = argAttr.
getName();
5938 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
5942 <<
"' attribute must be present only on kernel arguments";
5944 if (!isa<UnitAttr>(argAttr.
getValue()))
5945 return op->
emitError() <<
"'" << attrName <<
"' must be a unit attribute";
5946 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
5949 <<
"' attribute requires the argument to also have attribute '"
5950 << LLVM::LLVMDialect::getByValAttrName() <<
"'";
5961unsigned NVVMMemorySpaceAttr::getAddressSpace()
const {
5962 return static_cast<unsigned>(getValue());
5965bool NVVMMemorySpaceAttr::isValidLoad(
5966 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5967 const ::mlir::DataLayout *dataLayout,
5973bool NVVMMemorySpaceAttr::isValidStore(
5974 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5975 const ::mlir::DataLayout *dataLayout,
5981bool NVVMMemorySpaceAttr::isValidAtomicOp(
5982 ptr::AtomicBinOp op,
Type type, ptr::AtomicOrdering ordering,
5983 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
5986 assert(
false &&
"unimplemented, see TODO in the source.");
5990bool NVVMMemorySpaceAttr::isValidAtomicXchg(
5991 Type type, ptr::AtomicOrdering successOrdering,
5992 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
5993 const ::mlir::DataLayout *dataLayout,
5996 assert(
false &&
"unimplemented, see TODO in the source.");
6000bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
6004 assert(
false &&
"unimplemented, see TODO in the source.");
6008bool NVVMMemorySpaceAttr::isValidPtrIntCast(
6013 assert(
false &&
"unimplemented, see TODO in the source.");
6022 int optLevel, StringRef triple, StringRef chip,
6023 StringRef features, DictionaryAttr flags,
6025 if (optLevel < 0 || optLevel > 3) {
6026 emitError() <<
"The optimization level must be a number between 0 and 3.";
6029 if (triple.empty()) {
6030 emitError() <<
"The target triple cannot be empty.";
6034 emitError() <<
"The target chip cannot be empty.";
6037 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
6038 return mlir::isa_and_nonnull<StringAttr>(attr);
6040 emitError() <<
"All the elements in the `link` array must be strings.";
6046LogicalResult NVVMTargetAttr::verifyTarget(
Operation *gpuModule) {
6047 if (!getVerifyTarget())
6050 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
6053 "NVVM target attribute must be attached to a GPU module");
6056 const NVVMCheckSMVersion targetSMVersion =
6060 "Minimum NVVM target SM version is sm_20");
6064 ->
walk([&](Operation *op) {
6065 if (
auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
6066 const NVVMCheckSMVersion requirement =
6067 reqOp.getRequiredMinSMVersion();
6069 op->
emitOpError() <<
"is not supported on " << getChip();
6081#define GET_OP_CLASSES
6082#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
6084#define GET_ATTRDEF_CLASSES
6085#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
static void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyAddSubFOp(OpType op)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
static void printOperandList(OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static llvm::Value * castPtrToAddrSpace(llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
static LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > ®s)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
#define TCGEN05LDRED(SHAPE, NUM, TYPE)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static MMATypes inferPtxTypeFromResult(OpTy op)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static LogicalResult parseMmaTypeSignature(OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
static bool isPtrInSharedClusterSpace(mlir::Value ptr)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static bool isPtrInGenericSpace(mlir::Value ptr)
static void processOperandFragments(Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > ®Types, SmallVectorImpl< StringRef > &ignoreAttrNames)
static constexpr unsigned notIntrinsic
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class represents a diagnostic that is inflight and set to be reported.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
LogicalResult matchAndRewrite(SubFOp op, PatternRewriter &rewriter) const override
bool isMinimumSMVersion() const
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const
static const NVVMCheckSMVersion getTargetSMVersionFromStr(StringRef smVersionString)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.