31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
// File-local alias for llvm::Intrinsic::not_intrinsic, stored as an unsigned.
49static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
// Fragment (enclosing signature is not visible here): casts the type of `ptr`
// to LLVM::LLVMPointerType and returns whether its address space equals
// `targetAS` converted to unsigned.
56 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(
ptr.getType());
57 return ptrTy.getAddressSpace() ==
static_cast<unsigned>(targetAS);
// Fragment (function name and leading parameters are not visible here):
// converts `targetAS` to an unsigned address-space number and returns an
// address-space cast of `ptr` to an LLVM pointer type in that address space,
// created through `builder`.
74 NVVMMemorySpace targetAS) {
75 unsigned AS =
static_cast<unsigned>(targetAS);
76 return builder.CreateAddrSpaceCast(
77 ptr, llvm::PointerType::get(builder.getContext(), AS));
// Fragment (function name and switch header are not visible here): maps the
// dialect-level NVVM::CTAGroupKind to llvm::nvvm::CTAGroupKind:
//   CTA_1 -> CG_1, CTA_2 -> CG_2.
// Any value not handled by a visible case reaches llvm_unreachable.
81static llvm::nvvm::CTAGroupKind
84 case NVVM::CTAGroupKind::CTA_1:
85 return llvm::nvvm::CTAGroupKind::CG_1;
86 case NVVM::CTAGroupKind::CTA_2:
87 return llvm::nvvm::CTAGroupKind::CG_2;
89 llvm_unreachable(
"unsupported cta_group value");
// Fragment of a parameter-checking helper (its name is not visible here).
// Visible checks:
//  - `tensorDims` must lie in [1, 5];
//  - an error about im2col mode needing at least a 3-dimensional tensor
//    (its guarding condition is on a line not shown);
//  - when im2col offsets are supplied, `tensorDims` must equal
//    `numIm2ColOffsets + 2`.
101 size_t numIm2ColOffsets,
103 if (tensorDims < 1 || tensorDims > 5)
104 return emitError(loc,
"expects coordinates between 1 to 5 dimension");
112 "to use im2col mode, the tensor has to be at least 3-dimensional");
// A zero offset count skips this check entirely.
114 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
116 loc,
"im2col offsets must be 2 less than number of coordinates");
// Verifier for CpAsyncBulkTensorSharedCTAToGlobalOp.
// Visible checks:
//  - when a predicate is present, the mode must be TILE and no L2 cache hint
//    may be set (the messages call this the inline-ptx lowering);
//  - a per-mode switch over TILE / IM2COL / TILE_SCATTER4; the only visible
//    case body is an error stating that Scatter4 mode expects 5 coordinates
//    (its condition is on a line not shown).
121LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
122 TMAStoreMode mode = getMode();
126 if (getPredicate()) {
127 if (mode != TMAStoreMode::TILE)
128 return emitError(
"Inline-ptx lowering supported only for Tile mode.");
129 if (getL2CacheHint())
130 return emitError(
"Inline-ptx lowering unsupported with L2 cache-hint.");
135 case TMAStoreMode::TILE:
137 case TMAStoreMode::IM2COL:
139 case TMAStoreMode::TILE_SCATTER4:
141 return emitError(
"Scatter4 mode expects 5 coordinates");
// Verifier for CpAsyncOp. Visible checks, in order:
//  - the cache modifier must be CG or CA;
//  - the size must be 4, 8 or 16;
//  - the CG modifier is accepted only together with size 16.
146LogicalResult CpAsyncOp::verify() {
147 if (getModifier() != LoadCacheModifierKind::CG &&
148 getModifier() != LoadCacheModifierKind::CA)
149 return emitError(
"Only CG and CA cache modifiers are supported.");
150 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
151 return emitError(
"expected byte size to be either 4, 8 or 16.");
152 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
153 return emitError(
"CG cache modifier is only support for 16 bytes copy.");
// Fragment of a TMA-load parameter-checking helper (its signature is not
// visible here). It first requires `tensorDims` to lie in [1, 5], then
// dispatches on the load mode through the local lambda `checkTMALoadParams`.
160 if (tensorDims < 1 || tensorDims > 5)
161 return emitError(loc,
"expects coordinates between 1 to 5 dimension");
// checkTMALoadParams(mode, isIm2col, expectedIm2colOff):
//  - for im2col-style modes the tensor must have at least 3 dimensions;
//  - the number of supplied im2col offsets must equal `expectedIm2colOff`.
163 auto checkTMALoadParams = [&](TMALoadMode mode,
bool isIm2col,
164 size_t expectedIm2colOff) -> LogicalResult {
165 if (isIm2col && (tensorDims < 3))
168 <<
" mode, the tensor has to be at least 3-dimensional";
170 if (numIm2colOff != expectedIm2colOff)
171 return emitError(loc) <<
" im2col offsets expected " << expectedIm2colOff
172 <<
" (provided " << numIm2colOff <<
")";
// Expected im2col offset count per mode:
//   TILE                    -> 0
//   IM2COL                  -> tensorDims - 2
//   IM2COL_W / IM2COL_W_128 -> 2
//   TILE_GATHER4            -> 0, and exactly 5 coordinates are required
178 case TMALoadMode::TILE:
179 return checkTMALoadParams(mode,
false, 0);
180 case TMALoadMode::IM2COL:
181 return checkTMALoadParams(mode,
true, tensorDims - 2);
182 case TMALoadMode::IM2COL_W:
183 case TMALoadMode::IM2COL_W_128:
184 return checkTMALoadParams(mode,
true, 2);
185 case TMALoadMode::TILE_GATHER4:
186 return (tensorDims == 5)
187 ? checkTMALoadParams(mode,
false, 0)
188 :
emitError(loc,
"Gather4 mode expects 5 coordinates");
// Verifier for CpAsyncBulkTensorPrefetchOp. The visible tail passes the op's
// mode and location to a call whose callee and leading arguments are on a
// line not shown.
193LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
195 getMode(), getLoc());
// Verifier for CpAsyncBulkTensorGlobalToSharedClusterOp.
// Visible checks:
//  - predicate: supported only for shared::cluster mode and only for the
//    TILE and IM2COL modes (per the error messages);
//  - destination address space: must match NVVMMemorySpace::Shared when
//    `isCTAOnly` is set, NVVMMemorySpace::SharedCluster otherwise (the
//    messages name these as address-space 3 and 7 respectively);
//  - multicast mask and CTA group are rejected in shared::cta mode (the
//    enclosing conditions are partly on lines not shown).
// The tail passes the mode and location to a call whose callee is not shown.
198LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
199 TMALoadMode mode = getMode();
200 bool isCTAOnly = getIsCTAOnly();
201 if (getPredicate()) {
203 return emitError(
"Predicate is supported only for shared::cluster mode.");
204 if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
206 "Predicate is supported only for Tile and Im2col modes.");
208 NVVMMemorySpace expectedAS =
209 isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
210 unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().
getType())
212 if (AS != expectedAS)
215 ?
"Shared::cta destination requires address-space 3."
216 :
"Shared::cluster destination requires address-space 7.");
219 if (getMulticastMask())
220 return emitError(
"Multicast is not supported with shared::cta mode.");
222 return emitError(
"CTAGroup is not supported with shared::cta mode.");
227 getMode(), getLoc());
// Verifier for CpAsyncBulkTensorReduceOp. Switches on the store mode; the
// TILE and IM2COL case bodies are on lines not shown, while TILE_SCATTER4 is
// rejected outright with an error.
230LogicalResult CpAsyncBulkTensorReduceOp::verify() {
231 TMAStoreMode mode = getMode();
234 case TMAStoreMode::TILE:
236 case TMAStoreMode::IM2COL:
238 case TMAStoreMode::TILE_SCATTER4:
239 return emitError(
"Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
// Verifier for CpAsyncBulkGlobalToSharedClusterOp. Rejects a multicast mask
// when `isSharedCTA` is true (how `isSharedCTA` is computed is on a line not
// shown).
244LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
246 if (isSharedCTA && getMulticastMask())
247 return emitError(
"Multicast is not supported with shared::cta mode.");
// Fragment of an mbarrier checking helper (its name and leading parameters
// are not visible here). `retVal` defaults to a null Value.
// Visible checks:
//  - `scope` must be MemScopeKind::CTA or MemScopeKind::CLUSTER;
//  - an op flagged `isSharedCluster` (computed on a line not shown) must not
//    have a return value.
253 NVVM::MemScopeKind scope,
254 Value retVal =
nullptr) {
255 if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
256 return op->
emitError(
"mbarrier scope must be either CTA or Cluster");
// A null `retVal` converts to false, i.e. "no return value".
259 bool hasRetValue =
static_cast<bool>(retVal);
260 if (isSharedCluster && hasRetValue)
262 "mbarrier in shared_cluster space cannot return any value");
267LogicalResult MBarrierArriveOp::verify() {
272LogicalResult MBarrierArriveDropOp::verify() {
// Verifier for MBarrierArriveExpectTxOp. The visible checks all sit under
// the "predicate is present" branch:
//  - scope must be CTA;
//  - shared_cluster space, a return value, and relaxed semantics are each
//    rejected (the conditions for the first two are on lines not shown).
277LogicalResult MBarrierArriveExpectTxOp::verify() {
281 if (getPredicate()) {
282 if (getScope() != NVVM::MemScopeKind::CTA)
283 return emitError(
"mbarrier scope must be CTA when using predicate");
286 return emitError(
"mbarrier in shared_cluster space is not supported when "
290 return emitError(
"return-value is not supported when using predicate");
292 if (getRelaxed() ==
true)
293 return emitError(
"mbarrier with relaxed semantics is not supported when "
300LogicalResult MBarrierArriveDropExpectTxOp::verify() {
315 inferredReturnTypes.push_back(IntegerType::get(context, 64));
// Result-type inference hook for MBarrierArriveOp. The visible tail passes
// `inferredReturnTypes` to a call whose callee is on a line not shown.
320MBarrierArriveOp::inferReturnTypes(
MLIRContext *context,
321 std::optional<Location> location,
322 MBarrierArriveOp::Adaptor adaptor,
325 inferredReturnTypes);
// Result-type inference hook for MBarrierArriveDropOp. The visible tail
// passes `inferredReturnTypes` to a call whose callee is on a line not shown.
328LogicalResult MBarrierArriveDropOp::inferReturnTypes(
329 MLIRContext *context, std::optional<Location> location,
330 MBarrierArriveDropOp::Adaptor adaptor,
333 inferredReturnTypes);
// Result-type inference hook for MBarrierArriveExpectTxOp. It branches on
// whether the adaptor carries a predicate (the branch body is on a line not
// shown) and its visible tail passes `inferredReturnTypes` to a call whose
// callee is not shown.
336LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes(
337 MLIRContext *context, std::optional<Location> location,
338 MBarrierArriveExpectTxOp::Adaptor adaptor,
342 if (adaptor.getPredicate())
345 inferredReturnTypes);
// Result-type inference hook for MBarrierArriveDropExpectTxOp. The visible
// tail passes `inferredReturnTypes` to a call whose callee is on a line not
// shown.
348LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes(
349 MLIRContext *context, std::optional<Location> location,
350 MBarrierArriveDropExpectTxOp::Adaptor adaptor,
353 inferredReturnTypes);
363 return inferred == actual;
372bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(
TypeRange l,
376bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(
TypeRange l,
381LogicalResult MBarrierExpectTxOp::verify() {
385LogicalResult MBarrierCompleteTxOp::verify() {
389LogicalResult MBarrierTestWaitOp::verify() {
393LogicalResult MBarrierTryWaitOp::verify() {
// Verifier for ConvertFloatToTF32Op. Per the visible messages, relu is
// rejected together with the rna rounding mode, and only the rn, rz and rna
// rounding modes are accepted (the conditions are on lines not shown).
397LogicalResult ConvertFloatToTF32Op::verify() {
398 using RndMode = NVVM::FPRoundingMode;
402 return emitError(
"Relu not supported with rna rounding mode.");
409 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
// Verifier for ConvertF32x2ToF6x2Op. The destination type must be
// Float6E2M3FN or Float6E3M2FN; otherwise an error naming both types is
// emitted.
414LogicalResult ConvertF32x2ToF6x2Op::verify() {
417 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
419 << mlir::Float6E2M3FNType::get(ctx) <<
" and "
420 << mlir::Float6E3M2FNType::get(ctx)
421 <<
" types are supported for conversions from f32x2 to f6x2.";
// Verifier for ConvertF32x2ToF8x2Op. Checks are dispatched on the
// destination type (the dispatch header is on a line not shown):
//  - Float8E4M3FN / Float8E5M2: rounding mode must be RN; a SATFINITE-only
//    error is also visible (its condition is on a line not shown);
//  - Float8E8M0FNU: rounding mode must be RZ or RP; a relu-not-supported
//    error is also visible (its condition is on a line not shown);
//  - the visible tail builds an error listing the three supported types.
426LogicalResult ConvertF32x2ToF8x2Op::verify() {
427 using RndMode = NVVM::FPRoundingMode;
428 using SatMode = NVVM::SaturationMode;
// Snapshot the attributes once so the per-type lambdas can capture them.
430 bool isRoundingModeRN = getRnd() == RndMode::RN;
431 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
432 bool isRoundingModeRP = getRnd() == RndMode::RP;
433 bool isSatFinite = getSat() == SatMode::SATFINITE;
435 bool hasRelu = getRelu();
440 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
442 if (!isRoundingModeRN) {
443 return emitOpError(
"Only RN rounding mode is supported for "
444 "conversions from f32x2 to ")
445 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
446 << mlir::Float8E5M2Type::get(ctx) <<
" types";
449 return emitOpError(
"Only SATFINITE saturation mode is supported "
452 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
453 << mlir::Float8E5M2Type::get(ctx) <<
" types";
457 .Case<mlir::Float8E8M0FNUType>([&](
mlir::Type) -> LogicalResult {
458 if (!(isRoundingModeRZ || isRoundingModeRP)) {
459 return emitOpError(
"Only RZ and RP rounding modes are supported for "
460 "conversions from f32x2 to ")
461 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
464 return emitOpError(
"relu not supported for conversions to ")
465 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
471 << mlir::Float8E4M3FNType::get(ctx) <<
", "
472 << mlir::Float8E5M2Type::get(ctx) <<
", and "
473 << mlir::Float8E8M0FNUType::get(ctx)
475 "supported for conversions from f32x2 to f8x2";
// Verifier for ConvertF16x2ToF8x2Op. The destination type must be
// Float8E4M3FN or Float8E5M2; otherwise an error naming both types is
// emitted.
479LogicalResult ConvertF16x2ToF8x2Op::verify() {
482 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
484 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
485 << mlir::Float8E5M2Type::get(ctx)
486 <<
" types are supported for conversions from f16x2 to f8x2.";
// Verifier for ConvertBF16x2ToF8x2Op. Checks are dispatched on the
// destination type (the dispatch header is on a line not shown):
//  - Float8E4M3FN / Float8E5M2: rounding mode must be RN; a SATFINITE-only
//    error is also visible (its condition is on a line not shown);
//  - Float8E8M0FNU: rounding mode must be RZ or RP; a relu-not-supported
//    error is also visible (its condition is on a line not shown).
// NOTE(review): the visible tail is an llvm_unreachable rather than a
// diagnostic; presumably other destination types are excluded before this
// verifier runs -- confirm.
491LogicalResult ConvertBF16x2ToF8x2Op::verify() {
492 using RndMode = NVVM::FPRoundingMode;
493 using SatMode = NVVM::SaturationMode;
495 bool isRoundingModeRN = getRnd() == RndMode::RN;
496 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
497 bool isRoundingModeRP = getRnd() == RndMode::RP;
498 bool isSatFinite = getSat() == SatMode::SATFINITE;
499 bool hasRelu = getRelu();
504 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
506 if (!isRoundingModeRN)
507 return emitOpError(
"Only RN rounding mode is supported for "
508 "conversions from bf16x2 to ")
509 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
510 << mlir::Float8E5M2Type::get(ctx) <<
" types";
512 return emitOpError(
"Only SATFINITE saturation mode is supported "
513 "for conversions from bf16x2 to ")
514 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
515 << mlir::Float8E5M2Type::get(ctx) <<
" types";
518 .Case<mlir::Float8E8M0FNUType>([&](
mlir::Type) -> LogicalResult {
519 if (!(isRoundingModeRZ || isRoundingModeRP))
520 return emitOpError(
"Only RZ and RP rounding modes are supported for "
521 "conversions from bf16x2 to ")
522 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
524 return emitOpError(
"relu not supported for conversions to ")
525 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
529 llvm_unreachable(
"Invalid conversion in ConvertBF16x2ToF8x2Op");
// Verifier for ConvertF32x2ToF4x2Op. The destination type must be
// Float4E2M1FN; otherwise an error naming that type is emitted.
534LogicalResult ConvertF32x2ToF4x2Op::verify() {
537 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
539 << mlir::Float4E2M1FNType::get(ctx)
540 <<
" type is supported for conversions from f32x2 to f4x2.";
// Verifier for ConvertF8x2ToBF16x2Op. The visible checks apply only when the
// source type is Float8E8M0FNU:
//  - the saturation mode must be NONE;
//  - a scale factor must not be supplied;
//  - relu is rejected (its condition is on a line not shown).
545LogicalResult ConvertF8x2ToBF16x2Op::verify() {
547 if (llvm::isa<Float8E8M0FNUType>(getSrcType())) {
548 if (getSat() != SaturationMode::NONE)
550 "Only NONE saturation mode is supported for conversions from ")
551 << Float8E8M0FNUType::get(ctx) <<
" type";
552 if (getScaleFactor())
553 return emitOpError(
"scaleFactor not supported for conversions from ")
554 << Float8E8M0FNUType::get(ctx) <<
" type";
556 return emitOpError(
"relu not supported for conversions from ")
557 << Float8E8M0FNUType::get(ctx) <<
" type";
// Verifier for PermuteOp. Checks that the presence of the optional `hi`
// operand agrees with the permute mode: some modes require it and others do
// not accept it (which modes fall in which group is on lines not shown).
563LogicalResult PermuteOp::verify() {
564 using Mode = NVVM::PermuteMode;
565 bool hasHi =
static_cast<bool>(getHi());
572 return emitError(
"mode '") << getMode() <<
"' requires 'hi' operand.";
580 << getMode() <<
"' does not accept 'hi' operand.";
// Fragment of a rounding-mode checking helper (its signature is not visible
// here). Visible checks:
//  - `rnd` must be one of RN, RZ, RS;
//  - RS requires `hasRandomBits`;
//  - per the last message, random_bits is rejected for RN and RZ (its
//    condition is on a line not shown).
595 static constexpr FPRoundingMode validRndModes[] = {
596 FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
598 if (!llvm::is_contained(validRndModes, rnd)) {
600 "Only RN, RZ, and RS rounding modes are supported for "
601 "conversions from f32x2 to ")
605 if (rnd == FPRoundingMode::RS) {
606 if (!hasRandomBits) {
607 return op->
emitOpError(
"random_bits is required for RS rounding mode.");
612 "random_bits not supported for RN and RZ rounding modes.");
// Verifier for ConvertF32x2ToF16x2Op. The visible tail passes "random bits
// operand present?" as a bool, followed by `*this`, to a call whose callee
// and leading arguments are on a line not shown.
619LogicalResult ConvertF32x2ToF16x2Op::verify() {
621 getRandomBits() ?
true :
false, *
this);
// Verifier for ConvertF32x2ToBF16x2Op. The visible tail passes "random bits
// operand present?" as a bool, followed by `*this`, to a call whose callee
// and leading arguments are on a line not shown.
624LogicalResult ConvertF32x2ToBF16x2Op::verify() {
626 getRandomBits() ?
true :
false, *
this);
// Verifier for ConvertF32x4ToF8x4Op. The destination type must be
// Float8E4M3FN or Float8E5M2; otherwise an error naming both types is
// emitted.
629LogicalResult ConvertF32x4ToF8x4Op::verify() {
632 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
634 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
635 << mlir::Float8E5M2Type::get(ctx)
636 <<
" types are supported for conversions from f32x4 to f8x4.";
// Verifier for ConvertF32x4ToF6x4Op. The destination type must be
// Float6E2M3FN or Float6E3M2FN; otherwise an error naming both types is
// emitted.
641LogicalResult ConvertF32x4ToF6x4Op::verify() {
644 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
646 << mlir::Float6E2M3FNType::get(ctx) <<
" and "
647 << mlir::Float6E3M2FNType::get(ctx)
648 <<
" types are supported for conversions from f32x4 to f6x4.";
// Verifier for ConvertF32x4ToF4x4Op. The destination type must be
// Float4E2M1FN; otherwise an op error naming that type is emitted.
653LogicalResult ConvertF32x4ToF4x4Op::verify() {
656 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
657 return emitOpError(
"Only ") << mlir::Float4E2M1FNType::get(ctx)
658 <<
" type is supported for conversions from "
// Verifier for BulkStoreOp. Only an initVal of 0 is accepted; any other
// value produces an op error that echoes the offending value.
664LogicalResult BulkStoreOp::verify() {
665 if (getInitVal() != 0)
666 return emitOpError(
"only 0 is supported for initVal, got ") << getInitVal();
// Verifier for PMEventOp. Visible checks:
//  - at least one of the event id and the masked event id must be set
//    (the messages call them `id` and `mask`);
//  - they must not both be set;
//  - the id must lie in [0, 15].
// Returns success when no check fails.
670LogicalResult PMEventOp::verify() {
671 auto eventId = getEventId();
672 auto maskedEventId = getMaskedEventId();
673 if (!maskedEventId && !eventId) {
674 return emitOpError() <<
"either `id` or `mask` must be set";
677 if (maskedEventId && eventId) {
678 return emitOpError() <<
"`id` and `mask` cannot be set at the same time";
// NOTE(review): `eventId` is tested with `!eventId` above, so it appears to
// be an optional-like value; confirm this range comparison behaves as
// intended when it is unset (a guard may exist on a line not shown).
682 if (eventId < 0 || eventId > 15) {
683 return emitOpError() <<
"`id` must be between 0 and 15";
687 return llvm::success();
// Infers the MMATypes enumerator for an operand from its element type.
// Visible mapping:
//   f64                         -> MMATypes::f64
//   f16 or the 2 x f16 vector   -> MMATypes::f16
//   f32 when isAccumulator      -> MMATypes::f32
//   f32 when not isAccumulator  -> MMATypes::tf32
//   integer type                -> MMATypes::s32 (any further condition is
//                                  on a line not shown)
//   LLVM struct type            -> recurse on the first body element; the
//                                  empty-body result is on a line not shown
// The std::optional return type allows "not inferrable"; the fallthrough
// return itself is on a line not shown.
693std::optional<mlir::NVVM::MMATypes>
694MmaOp::inferOperandMMAType(
Type operandElType,
bool isAccumulator) {
696 VectorType::get(2, Float16Type::get(operandElType.
getContext()));
697 if (operandElType.
isF64())
698 return NVVM::MMATypes::f64;
699 if (operandElType.
isF16() || operandElType == half2Type)
700 return NVVM::MMATypes::f16;
701 if (operandElType.
isF32() && isAccumulator)
702 return NVVM::MMATypes::f32;
703 if (operandElType.
isF32() && !isAccumulator)
704 return NVVM::MMATypes::tf32;
705 if (llvm::isa<IntegerType>(operandElType)) {
707 return NVVM::MMATypes::s32;
711 if (
auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
712 if (structType.getBody().empty())
714 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
721 return (type == MMATypes::u4 || type == MMATypes::s4);
725 return (type == MMATypes::u8 || type == MMATypes::s8);
730 type == MMATypes::s32;
// Infers the PTX type of the accumulator from the type of the first value in
// ODS operand group 2, with isAccumulator = true. Asserts that inference
// succeeded; the return statement is on a line not shown.
733MMATypes MmaOp::accumPtxType() {
734 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
735 getODSOperands(2).getTypes().front(),
true);
736 assert(val.has_value() &&
"accumulator PTX type should always be inferrable");
// Infers the PTX type of the result from the op's result type, with
// isAccumulator = true. Asserts that inference succeeded; the return
// statement is on a line not shown.
740MMATypes MmaOp::resultPtxType() {
741 std::optional<mlir::NVVM::MMATypes> val =
742 inferOperandMMAType(getResult().
getType(),
true);
743 assert(val.has_value() &&
"result PTX type should always be inferrable");
// Fragment that looks like the MmaOp custom printer (it streams into `p`;
// the enclosing signature is not visible here).
// Each MMAOperandFragment holds an operand group's display name, the name of
// its PTX-type attribute (empty for C), and the group's values.
749 struct MMAOperandFragment {
750 StringRef operandName;
751 StringRef ptxTypeAttr;
752 SmallVector<Value, 4> regs;
753 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
754 : operandName(name), ptxTypeAttr(ptxTypeName) {}
757 std::array<MMAOperandFragment, 3> frags{
758 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
759 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
760 MMAOperandFragment(
"C",
"")};
762 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
// Collect the values of every operand group. When a group's PTX type can be
// inferred from a register type, its attribute name is added to
// `ignoreAttrNames` (the guarding `if` is on a line not shown).
764 for (
unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
765 auto &frag = frags[fragIdx];
766 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
767 for (
auto operandIdx = varOperandSpec.first;
768 operandIdx < varOperandSpec.first + varOperandSpec.second;
770 frag.regs.push_back(this->getOperand(operandIdx));
// NOTE(review): this compares against the absolute operand index 0, so a
// type is pushed only for the very first operand and `regTypes.back()` below
// is then that same type for every group. The MmaSpOp counterpart in this
// file compares against `varOperandSpec.first` instead -- confirm which is
// intended.
771 if (operandIdx == 0) {
772 regTypes.push_back(this->getOperand(operandIdx).
getType());
// Only the third group (fragIdx >= 2, i.e. C) is treated as an accumulator.
775 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
776 regTypes.back(), fragIdx >= 2);
778 ignoreAttrNames.push_back(frag.ptxTypeAttr);
781 auto printMmaOperand = [&](
const MMAOperandFragment &frag) ->
void {
782 p <<
" " << frag.operandName;
788 for (
const auto &frag : frags) {
789 printMmaOperand(frag);
797 frags[1].regs[0].getType(),
798 frags[2].regs[0].getType()},
// Fragment that looks like the MmaOp builder (it fills an operation state
// called `result`; the function name and leading parameters are not visible
// here). Visible behaviour:
//  - asserts that `shape` has exactly 3 entries (m, n, k);
//  - adds operand groups A, B, C in that order;
//  - sets multiplicandAPtxType / multiplicandBPtxType from
//    `multiplicandPtxTypes` when given; the visible inference path derives
//    them from the type of the first A / B operand (presumably the
//    else-branch -- the `else` is on a line not shown);
//  - sets layoutA / layoutB from `multiplicandLayouts` when given; the
//    visible alternative uses row for A and col for B;
//  - adds intOverflowBehavior and b1Op only when they have values;
//  - adds the result type and the operand segment sizes.
807 std::optional<MMAIntOverflow> intOverflow,
808 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
809 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
811 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
816 result.addOperands(operandA);
817 result.addOperands(operandB);
818 result.addOperands(operandC);
820 if (multiplicandPtxTypes) {
821 result.addAttribute(
"multiplicandAPtxType",
822 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
823 result.addAttribute(
"multiplicandBPtxType",
824 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
// Inference is attempted independently for A and B; an attribute is added
// only when inference yields a value.
826 if (
auto res = inferOperandMMAType(operandA[0].
getType(),
false))
827 result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
828 if (
auto res = inferOperandMMAType(operandB[0].
getType(),
false))
829 result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
832 if (multiplicandLayouts) {
833 result.addAttribute(
"layoutA",
834 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
835 result.addAttribute(
"layoutB",
836 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
838 result.addAttribute(
"layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
839 result.addAttribute(
"layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
842 if (intOverflow.has_value())
843 result.addAttribute(
"intOverflowBehavior",
844 MMAIntOverflowAttr::get(ctx, *intOverflow));
845 if (b1Op.has_value())
846 result.addAttribute(
"b1Op", MMAB1OpAttr::get(ctx, *b1Op));
848 result.addTypes(resultType);
850 MmaOp::getOperandSegmentSizeAttr(),
852 static_cast<int32_t>(operandB.size()),
853 static_cast<int32_t>(operandC.size())}));
// Fragment that looks like the MmaOp custom parser (it uses `parser` and
// fills `result`; the enclosing signature is not visible here).
// Each MMAOperandFragment holds a group's inferred element type (if any),
// its unresolved operands, and their types. Four fragments are declared;
// index 3 is used for the result type.
861 struct MMAOperandFragment {
862 std::optional<MMATypes> elemtype;
863 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
864 SmallVector<Type> regTypes;
868 std::array<MMAOperandFragment, 4> frags;
874 MMAOperandFragment &frag) -> LogicalResult {
// Exactly one type is expected per operand segment (A, B, C).
904 if (operandTypes.size() != 3)
907 "expected one type for each operand segment but got " +
908 Twine(operandTypes.size()) +
" types");
// Every register of a segment gets the segment's single declared type.
909 for (
const auto &iter : llvm::enumerate(operandTypes)) {
910 auto &frag = frags[iter.index()];
911 frag.regTypes.resize(frag.regs.size(), iter.value());
915 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
922 frags[3].elemtype = inferOperandMMAType(resultType,
true);
// For A and B: error out if the PTX type is neither given as an attribute
// nor inferrable; when no attribute was given, an MMATypesAttr built from
// the inferred type is supplied (the receiving call is on a line not shown).
924 std::array<StringRef, 2> names{
"multiplicandAPtxType",
925 "multiplicandBPtxType"};
926 for (
unsigned idx = 0; idx < names.size(); idx++) {
927 const auto &frag = frags[idx];
928 std::optional<NamedAttribute> attr = namedAttributes.
getNamed(names[idx]);
929 if (!frag.elemtype.has_value() && !attr.has_value()) {
932 "attribute " + names[idx] +
933 " is not provided explicitly and cannot be inferred");
935 if (!attr.has_value())
937 names[idx], MMATypesAttr::get(parser.
getContext(), *frag.elemtype));
940 result.addTypes(resultType);
941 if (!namedAttributes.
empty())
942 result.addAttributes(namedAttributes);
943 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
945 static_cast<int32_t>(frags[0].regs.size()),
946 static_cast<int32_t>(frags[1].regs.size()),
947 static_cast<int32_t>(frags[2].regs.size()),
// Verifier for MmaOp. Outline of the visible logic:
//  1. Build the candidate types (f16, i32, f32, 2 x f16 vector and several
//     literal LLVM struct types).
//  2. Read the MMA shape {M, N, K} from the shape attribute.
//  3. Depending on M (16 or 8) and the A multiplicand PTX type, fill the
//     lists of allowed shapes, allowed A/B/C operand type segments and
//     allowed result types.
//  4. Report "unimplemented variant" if any expected list is empty or the
//     shape is not allowed.
//  5. Check each operand segment and the result type against the allowed
//     lists.
//  6. Attribute checks: b1 multiplicands require the b1Op attribute; a
//     check on the int-overflow attribute is partly visible; unless the op
//     is m8n8k4 with f16, layoutA must be row and layoutB must be col.
952LogicalResult MmaOp::verify() {
954 auto f16Ty = Float16Type::get(context);
955 auto i32Ty = IntegerType::get(context, 32);
956 auto f16x2Ty = VectorType::get(2, f16Ty);
957 auto f32Ty = Float32Type::get(context);
958 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
959 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
962 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
965 auto f16x2x2StructTy =
966 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
968 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
970 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
972 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
973 getShapeAttr().getK()};
979 AllowedShapes allowedShapes;
980 AllowedTypes expectedA;
981 AllowedTypes expectedB;
982 AllowedTypes expectedC;
// M == 16 variants: the fragment type and expected result/accumulator types
// depend on the A multiplicand PTX type (case labels are on lines not shown).
987 if (mmaShape[0] == 16) {
989 Type multiplicandFragType;
990 switch (*getMultiplicandAPtxType()) {
993 multiplicandFragType = i32Ty;
994 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
995 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
999 multiplicandFragType = i32Ty;
1000 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1001 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1005 multiplicandFragType = f16x2Ty;
1006 expectedResult.push_back(f16x2x2StructTy);
1007 expectedResult.push_back(f32x4StructTy);
1021 return emitError(
"invalid shape or multiplicand type: ")
1022 << getMultiplicandAPtxType().value();
1026 expectedResult.push_back(s32x4StructTy);
1027 expectedC.emplace_back(4, i32Ty);
1028 multiplicandFragType = i32Ty;
1030 expectedC.emplace_back(2, f16x2Ty);
1031 expectedC.emplace_back(4, f32Ty);
// Number of A / B fragment registers, derived from the shape and `kFactor`
// (`kFactor` is assigned on lines not shown).
1034 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
1035 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1036 expectedA.emplace_back(unitA, multiplicandFragType);
1037 expectedB.emplace_back(unitB, multiplicandFragType);
1038 allowedShapes.push_back({16, 8, kFactor});
1039 allowedShapes.push_back({16, 8, kFactor * 2});
1041 if (resultPtxType() != accumPtxType())
// M == 8 variants.
1046 if (mmaShape[0] == 8) {
1047 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1048 expectedA.emplace_back(2, f16x2Ty);
1049 expectedB.emplace_back(2, f16x2Ty);
1050 expectedResult.push_back(f16x2x4StructTy);
1051 expectedResult.push_back(f32x8StructTy);
1052 expectedC.emplace_back(4, f16x2Ty);
1053 expectedC.emplace_back(8, f32Ty);
1054 allowedShapes.push_back({8, 8, 4});
1056 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1057 Type f64Ty = Float64Type::get(context);
1058 expectedA.emplace_back(1, f64Ty);
1059 expectedB.emplace_back(1, f64Ty);
1060 expectedC.emplace_back(2, f64Ty);
1061 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1063 allowedShapes.push_back({8, 8, 4});
// Variants with i32 fragments (the guarding condition is on a line not
// shown); the visible lines allow K = 32, K = 16 and, for b1, K = 128.
1066 expectedA.push_back({i32Ty});
1067 expectedB.push_back({i32Ty});
1068 expectedC.push_back({i32Ty, i32Ty});
1069 expectedResult.push_back(s32x2StructTy);
1071 allowedShapes.push_back({8, 8, 32});
1073 allowedShapes.push_back({8, 8, 16});
1074 if (getMultiplicandAPtxType().value() == MMATypes::b1)
1075 allowedShapes.push_back({8, 8, 128});
// Diagnostics are assembled into a string before being emitted.
1079 std::string errorMessage;
1080 llvm::raw_string_ostream errorStream(errorMessage);
1083 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1084 !llvm::is_contained(allowedShapes, mmaShape)) {
1085 errorStream <<
"unimplemented variant for MMA shape <";
1086 llvm::interleaveComma(mmaShape, errorStream);
// Compare the actual type list of each operand segment with the allowed
// lists for A, B and C.
1092 std::array<StringRef, 3> operandNames{
"A",
"B",
"C"};
1093 for (
const auto &iter : llvm::enumerate(
1095 auto spec = this->getODSOperandIndexAndLength(iter.index());
1097 operand_type_begin() + spec.first +
1099 bool match = llvm::is_contained(iter.value(), operandTySeg);
1102 errorStream <<
"Could not match types for the "
1103 << operandNames[iter.index()]
1104 <<
" operands; expected one of ";
1105 for (
const auto &x : iter.value()) {
1106 errorStream << x.size() <<
"x" << x[0] <<
" ";
1108 errorStream <<
"but got ";
1109 llvm::interleaveComma(operandTySeg, errorStream);
// The result type must equal one of the expected result types.
1115 if (!llvm::any_of(expectedResult, [&](
Type expectedResultType) {
1116 return expectedResultType == getResult().getType();
1119 <<
"Could not match allowed types for the result; expected one of ";
1120 llvm::interleaveComma(expectedResult, errorStream);
1121 errorStream <<
" but got " << getResult().getType();
1126 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
1127 return emitOpError(
"op requires " + getB1OpAttrName().strref() +
1135 if (!getIntOverflowBehavior())
1137 getIntOverflowBehaviorAttrName().strref() +
1145 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
1146 getMultiplicandAPtxType() == MMATypes::f16);
// Every variant other than m8n8k4 with f16 requires row-major A and
// column-major B.
1148 if (!isM8N8K4_F16) {
1150 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
1151 return emitOpError(
"requires layoutA = #nvvm.mma_layout<row> and "
1152 "layoutB = #nvvm.mma_layout<col> for shape <")
1153 << mmaShape[0] <<
", " << mmaShape[1] <<
", " << mmaShape[2]
1154 <<
"> with element types " << *getMultiplicandAPtxType() <<
" and "
1155 << *getMultiplicandBPtxType()
1156 <<
". Only m8n8k4 with f16 supports other layouts.";
// Infers the PTX type of the accumulator from the type of the first value in
// ODS operand group 2, reusing MmaOp::inferOperandMMAType with
// isAccumulator = true. Asserts that inference succeeded; the return
// statement is on a line not shown.
1163MMATypes MmaSpOp::accumPtxType() {
1164 std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
1165 getODSOperands(2).getTypes().front(),
true);
1166 assert(val.has_value() &&
"accumulator PTX type should always be inferrable");
// Infers the PTX type of the result from the op's result type, reusing
// MmaOp::inferOperandMMAType with isAccumulator = true. Asserts that
// inference succeeded; the return statement is on a line not shown.
1170MMATypes MmaSpOp::resultPtxType() {
1171 std::optional<mlir::NVVM::MMATypes> val =
1172 MmaOp::inferOperandMMAType(getResult().
getType(),
true);
1173 assert(val.has_value() &&
"result PTX type should always be inferrable");
// Fragment (function name and leading parameters are not visible here):
// casts `op` to NVVM::MmaSpOp, looks up an intrinsic ID from the op's shape
// (M, N, K), int-overflow behaviour, ordered-metadata flag, kind, A/B
// multiplicand PTX types and accumulator/result PTX types, and returns that
// ID paired with `args` (built on lines not shown).
1179 llvm::IRBuilderBase &builder) {
1180 auto thisOp = cast<NVVM::MmaSpOp>(op);
1188 auto intId = MmaSpOp::getIntrinsicID(
1189 thisOp.getShape().getM(), thisOp.getShape().getN(),
1190 thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
1191 thisOp.getOrderedMetadata(), thisOp.getKind(),
1192 *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
1193 thisOp.accumPtxType(), thisOp.resultPtxType());
1195 return {intId, args};
// Fragment that looks like the MmaSpOp custom printer (it streams into `p`;
// the enclosing signature is not visible here).
// Each MMAOperandFragment holds an operand group's display name, the name of
// its PTX-type attribute (empty when there is none), and the group's values.
// Five fragments are used: A, B, C, sparseMetadata, selector.
1200 struct MMAOperandFragment {
1201 StringRef operandName;
1202 StringRef ptxTypeAttr;
1203 SmallVector<Value, 4> regs;
1204 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1205 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1208 std::array<MMAOperandFragment, 5> frags{
1209 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
1210 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
1211 MMAOperandFragment(
"C",
""), MMAOperandFragment(
"sparseMetadata",
""),
1212 MMAOperandFragment(
"selector",
"")};
1214 mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
// Only the first three groups (A, B, C) are walked through the ODS operand
// segments; the type of the first operand of each group is recorded.
1217 for (
unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
1218 auto &frag = frags[fragIdx];
1219 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
1220 for (
auto operandIdx = varOperandSpec.first;
1221 operandIdx < varOperandSpec.first + varOperandSpec.second;
1223 frag.regs.push_back(this->getOperand(operandIdx));
1224 if (operandIdx == varOperandSpec.first) {
1225 regTypes.push_back(this->getOperand(operandIdx).
getType());
// Only the third group (fragIdx >= 2, i.e. C) is treated as an accumulator.
1228 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
1229 regTypes.back(), fragIdx >= 2);
1231 ignoreAttrNames.push_back(frag.ptxTypeAttr);
// The two remaining groups are single values taken from dedicated accessors.
1235 frags[3].regs.push_back(getSparseMetadata());
1236 frags[4].regs.push_back(getSparsitySelector());
1238 auto printMmaSpOperand = [&](
const MMAOperandFragment &frag) ->
void {
1239 p <<
" " << frag.operandName;
1245 for (
const auto &frag : frags)
1246 printMmaSpOperand(frag);
1251 for (
int i = 0; i < 3; ++i) {
1256 p <<
") -> " << getResult().getType();
// Fragment that looks like the MmaSpOp builder (it fills an operation state
// called `result`; the function name and leading parameters are not visible
// here). Visible behaviour:
//  - asserts that `shape` has exactly 3 entries (m, n, k);
//  - adds operands in the order A, B, C, sparseMetadata, sparsitySelector;
//  - sets multiplicandAPtxType / multiplicandBPtxType from
//    `multiplicandPtxTypes` when given; the visible inference path derives
//    them from the type of the first A / B operand (presumably the
//    else-branch -- the `else` is on a line not shown);
//  - adds intOverflowBehavior only when it has a value;
//  - adds the result type and the operand segment sizes.
1263 std::optional<MMAIntOverflow> intOverflow,
1264 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1266 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
1271 result.addOperands(operandA);
1272 result.addOperands(operandB);
1273 result.addOperands(operandC);
1274 result.addOperands(sparseMetadata);
1275 result.addOperands(sparsitySelector);
1277 if (multiplicandPtxTypes) {
1278 result.addAttribute(
"multiplicandAPtxType",
1279 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1280 result.addAttribute(
"multiplicandBPtxType",
1281 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
// Inference is attempted independently for A and B; an attribute is added
// only when inference yields a value.
1283 if (
auto res = MmaOp::inferOperandMMAType(operandA[0].
getType(),
false))
1284 result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1285 if (
auto res = MmaOp::inferOperandMMAType(operandB[0].
getType(),
false))
1286 result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1289 if (intOverflow.has_value())
1290 result.addAttribute(
"intOverflowBehavior",
1291 MMAIntOverflowAttr::get(ctx, *intOverflow));
1293 result.addTypes(resultType);
1295 MmaSpOp::getOperandSegmentSizeAttr(),
1297 static_cast<int32_t>(operandB.size()),
1298 static_cast<int32_t>(operandC.size()), 1,
// Fragment that looks like the MmaSpOp custom parser (it uses `parser` and
// fills `result`; the enclosing signature is not visible here).
// Each MMAOperandFragment holds a group's inferred element type (if any),
// its unresolved operands, and their types. Six fragments are declared; the
// visible code parses into indices 0-4 (A, B, C, sparseMetadata, selector).
1303 struct MMAOperandFragment {
1304 std::optional<MMATypes> elemtype;
1305 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1306 SmallVector<Type> regTypes;
1310 std::array<MMAOperandFragment, 6> frags;
1315 auto parseMmaSpOperand = [&](StringRef operandName,
1316 MMAOperandFragment &frag) -> LogicalResult {
// Parse the five named operand groups in their fixed textual order.
1327 if (parseMmaSpOperand(
"A", frags[0]).
failed())
1329 if (parseMmaSpOperand(
"B", frags[1]).
failed())
1331 if (parseMmaSpOperand(
"C", frags[2]).
failed())
1333 if (parseMmaSpOperand(
"sparseMetadata", frags[3]).
failed())
1335 if (parseMmaSpOperand(
"selector", frags[4]).
failed())
// Exactly three operand types are expected in the type list.
1351 if (operandTypes.size() != 3)
1354 "expected one type for each operand segment but got " +
1355 Twine(operandTypes.size()) +
" types");
// Every register of a segment gets the segment's single declared type.
1356 for (
const auto &iter : llvm::enumerate(operandTypes)) {
1357 auto &frag = frags[iter.index()];
1358 frag.regTypes.resize(frag.regs.size(), iter.value());
1363 MmaOp::inferOperandMMAType(frag.regTypes[0],
1371 MmaOp::inferOperandMMAType(resultType,
true);
// For A and B: error out if the PTX type is neither given as an attribute
// nor inferrable; when no attribute was given, an MMATypesAttr built from
// the inferred type is supplied (the receiving call is on a line not shown).
1386 std::array<StringRef, 2> names{
"multiplicandAPtxType",
1387 "multiplicandBPtxType"};
1388 for (
unsigned idx = 0; idx < names.size(); idx++) {
1389 const auto &frag = frags[idx];
1390 std::optional<NamedAttribute> attr = namedAttributes.
getNamed(names[idx]);
1391 if (!frag.elemtype.has_value() && !attr.has_value()) {
1394 "attribute " + names[idx] +
1395 " is not provided explicitly and cannot be inferred");
1397 if (!attr.has_value())
1399 names[idx], MMATypesAttr::get(parser.
getContext(), *frag.elemtype));
1402 result.addTypes(resultType);
1403 if (!namedAttributes.
empty())
1404 result.addAttributes(namedAttributes);
1405 result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
1407 static_cast<int32_t>(frags[0].regs.size()),
1408 static_cast<int32_t>(frags[1].regs.size()),
1409 static_cast<int32_t>(frags[2].regs.size()),
// Verifier for MmaSpOp: checks the op's shape, operand fragment types and
// result type against the variants accepted for the A multiplicand PTX type,
// and checks that the sparse metadata and sparsity selector are i32.
// NOTE(review): this listing skips lines (see the embedded line numbers), so
// the comments below describe only the statements that are visible.
1416LogicalResult MmaSpOp::verify() {
// Element types and literal struct types used below to describe the expected
// operand fragments and results.
1418 auto f16Ty = Float16Type::get(context);
1419 auto i32Ty = IntegerType::get(context, 32);
1420 auto f16x2Ty = VectorType::get(2, f16Ty);
1421 auto f32Ty = Float32Type::get(context);
1422 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
1423 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
1425 auto s32x4StructTy =
1426 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
1427 auto f32x8StructTy =
1429 auto f16x2x2StructTy =
1430 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
1431 auto f32x4StructTy =
1432 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
1433 auto s32x2StructTy =
1434 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
// The (m, n, k) shape read from the op's shape attribute.
1436 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1437 getShapeAttr().getK()};
// Collections filled in below with the shapes and the A/B/C operand types
// accepted for the selected variant.
1443 AllowedShapes allowedShapes;
1444 AllowedTypes expectedA;
1445 AllowedTypes expectedB;
1446 AllowedTypes expectedC;
// m == 16 variants: the A multiplicand PTX type selects the fragment type,
// the accepted result types and the accepted shapes.
1451 if (mmaShape[0] == 16) {
1453 Type multiplicandFragType;
1454 switch (*getMultiplicandAPtxType()) {
1455 case MMATypes::tf32:
1457 multiplicandFragType = i32Ty;
1458 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1459 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1461 allowedShapes.push_back({16, 8, 8});
1462 allowedShapes.push_back({16, 8, 16});
1464 case MMATypes::bf16:
1466 multiplicandFragType = i32Ty;
1467 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1468 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1470 allowedShapes.push_back({16, 8, 16});
1471 allowedShapes.push_back({16, 8, 32});
1475 multiplicandFragType = f16x2Ty;
1476 expectedResult.push_back(f16x2x2StructTy);
1477 expectedResult.push_back(f32x4StructTy);
1479 allowedShapes.push_back({16, 8, 16});
1480 allowedShapes.push_back({16, 8, 32});
1486 allowedShapes.push_back({16, 8, 64});
1487 allowedShapes.push_back({16, 8, 128});
1493 allowedShapes.push_back({16, 8, 32});
1494 allowedShapes.push_back({16, 8, 64});
1496 case MMATypes::e4m3:
1497 case MMATypes::e5m2:
1498 case MMATypes::e3m2:
1499 case MMATypes::e2m3:
1500 case MMATypes::e2m1:
1502 multiplicandFragType = i32Ty;
1503 expectedResult.push_back(f16x2x2StructTy);
1504 expectedResult.push_back(f32x4StructTy);
1506 allowedShapes.push_back({16, 8, 64});
1509 return emitError(
"invalid shape or multiplicand type: ")
1510 << getMultiplicandAPtxType().value();
1514 expectedResult.push_back(s32x4StructTy);
1515 expectedC.emplace_back(4, i32Ty);
1516 multiplicandFragType = i32Ty;
1517 }
else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
1518 *getMultiplicandAPtxType() <= MMATypes::e2m1) {
1520 expectedC.emplace_back(2, f16x2Ty);
1521 expectedC.emplace_back(4, f32Ty);
1523 expectedC.emplace_back(2, f16x2Ty);
1524 expectedC.emplace_back(4, f32Ty);
// Number of A and B fragment registers derived from the shape; the A count
// carries an extra division by 2 compared with the B formula.
// NOTE(review): kFactor is assigned on lines not shown in this listing.
1529 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
1530 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1531 expectedA.emplace_back(unitA, multiplicandFragType);
1532 expectedB.emplace_back(unitB, multiplicandFragType);
1534 if (resultPtxType() != accumPtxType())
// m == 8 variants.
1539 if (mmaShape[0] == 8) {
1540 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1541 expectedA.emplace_back(2, f16x2Ty);
1542 expectedB.emplace_back(2, f16x2Ty);
1543 expectedResult.push_back(f16x2x4StructTy);
1544 expectedResult.push_back(f32x8StructTy);
1545 expectedC.emplace_back(4, f16x2Ty);
1546 expectedC.emplace_back(8, f32Ty);
1547 allowedShapes.push_back({8, 8, 4});
1549 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1550 Type f64Ty = Float64Type::get(context);
1551 expectedA.emplace_back(1, f64Ty);
1552 expectedB.emplace_back(1, f64Ty);
1553 expectedC.emplace_back(2, f64Ty);
1554 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1556 allowedShapes.push_back({8, 8, 4});
1559 expectedA.push_back({i32Ty});
1560 expectedB.push_back({i32Ty});
1561 expectedC.push_back({i32Ty, i32Ty});
1562 expectedResult.push_back(s32x2StructTy);
1564 allowedShapes.push_back({8, 8, 32});
1566 allowedShapes.push_back({8, 8, 16});
// Diagnostic text is accumulated into errorMessage through errorStream.
1570 std::string errorMessage;
1571 llvm::raw_string_ostream errorStream(errorMessage);
// Reject the op when no expectation was recorded for A, B or C, or when the
// shape is not one of the allowed shapes.
1574 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1575 !llvm::is_contained(allowedShapes, mmaShape)) {
1576 errorStream <<
"unimplemented variant for MMA shape <";
1577 llvm::interleaveComma(mmaShape, errorStream);
// Compare the types of each of the A, B and C operand segments with the
// corresponding list of allowed type sequences.
1583 std::array<StringRef, 3> operandNames{
"A",
"B",
"C"};
1584 for (
const auto &iter : llvm::enumerate(
1586 auto spec = this->getODSOperandIndexAndLength(iter.index());
1588 operand_type_begin() + spec.first +
1590 bool match = llvm::is_contained(iter.value(), operandTySeg);
1593 errorStream <<
"Could not match types for the "
1594 << operandNames[iter.index()]
1595 <<
" operands; expected one of ";
1596 for (
const auto &x : iter.value()) {
1597 errorStream << x.size() <<
"x" << x[0] <<
" ";
1599 errorStream <<
"but got ";
1600 llvm::interleaveComma(operandTySeg, errorStream);
// The result type must be equal to one of the expected result types.
1606 if (!llvm::any_of(expectedResult, [&](
Type expectedResultType) {
1607 return expectedResultType == getResult().getType();
1610 <<
"Could not match allowed types for the result; expected one of ";
1611 llvm::interleaveComma(expectedResult, errorStream);
1612 errorStream <<
" but got " << getResult().getType();
1620 if (!getIntOverflowBehavior())
1622 getIntOverflowBehaviorAttrName().strref() +
// The sparse metadata operand must be a 32-bit integer.
1627 if (!getSparseMetadata().
getType().isInteger(32)) {
1628 return emitOpError() <<
"sparse metadata must be i32 type";
// The sparsity selector operand must be a 32-bit integer.
1632 if (!getSparsitySelector().
getType().isInteger(32)) {
1633 return emitOpError() <<
"sparsity selector must be i32 type";
// Groups the values of one MMA operand together with the operand's display
// name and the name of the attribute that holds its PTX type.
1645struct MMAOperandFragment {
// Name of the operand (callers in this file pass "A", "B" or "C").
1646 StringRef operandName;
// Name of the PTX-type attribute for this operand (callers pass "" for "C").
1647 StringRef ptxTypeAttr;
// Values that make up this operand.
1648 SmallVector<Value, 4> regs;
// Stores both names; regs starts out empty.
1649 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1650 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1657 p <<
" " << name <<
"[";
1676template <
typename Op>
1681 for (
unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
1682 auto &frag = frags[fragIdx];
1683 auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
1684 for (
auto operandIdx = varOperandSpec.first;
1685 operandIdx < varOperandSpec.first + varOperandSpec.second;
1687 frag.regs.push_back(op.getOperand(operandIdx));
1688 if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
1689 regTypes.push_back(op.getOperand(operandIdx).getType());
1693 regTypes.push_back(frag.regs[0].getType());
1695 std::optional<MMATypes> inferredType =
1696 MmaOp::inferOperandMMAType(regTypes.back(),
1699 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1710 auto typeParser = [&]() {
1714 operandTypes.push_back(ty);
1720 if (operandTypes.size() != 3)
1722 "expected exactly 3 types");
1731 if (!attrs.
get(
"multiplicandAPtxType")) {
1732 if (
auto inferredType =
1733 MmaOp::inferOperandMMAType(operandTypes[0],
false)) {
1734 attrs.
set(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
1737 if (!attrs.
get(
"multiplicandBPtxType")) {
1738 if (
auto inferredType =
1739 MmaOp::inferOperandMMAType(operandTypes[1],
false)) {
1740 attrs.
set(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
1746template <
typename OpType>
1749 ScaleVecSize scaleVecSize,
1750 BlockScaleFormat blockScaleFormat,
1751 MMABlockScaleKind kind) {
1753 auto &properties =
result.getOrAddProperties<
typename OpType::Properties>();
1754 properties.setShape(
1756 properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
1757 properties.setBlockScaleFormat(
1758 BlockScaleFormatAttr::get(ctx, blockScaleFormat));
1759 properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
1766 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1767 if (multiplicandPtxTypes) {
1768 result.addAttribute(
"multiplicandAPtxType",
1769 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1770 result.addAttribute(
"multiplicandBPtxType",
1771 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1773 if (
auto res = MmaOp::inferOperandMMAType(operandA[0].
getType(),
false))
1774 result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1775 if (
auto res = MmaOp::inferOperandMMAType(operandB[0].
getType(),
false))
1776 result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1781template <
typename OpTy>
1783 return *MmaOp::inferOperandMMAType(
1784 cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
1794 std::array<MMAOperandFragment, 3> frags{
1795 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
1796 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
1797 MMAOperandFragment(
"C",
"")};
1799 mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
1804 for (
const auto &frag : frags)
1809 {getScaleAData(), getByteIdA(), getThreadIdA()});
1811 {getScaleBData(), getByteIdB(), getThreadIdB()});
1818 frags[1].regs[0].getType(),
1819 frags[2].regs[0].getType()},
// Custom assembly parser for MmaBlockScaleOp.
// NOTE(review): most of this definition is missing from the listing; the
// comments cover only the visible statements.
1825ParseResult MmaBlockScaleOp::parse(
OpAsmParser &parser,
// Parser-local record for one operand group: its unresolved registers and
// the MMA element type inferred for it.
1827 struct LocalOperandFragment {
1828 std::optional<MMATypes> elemtype;
1829 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
// One fragment per operand group (indices 0, 1, 2).
1833 std::array<LocalOperandFragment, 3> frags;
// For every fragment, infer the element type from the parsed operand type at
// the same index and resolve the fragment's registers against that type.
1862 for (
const auto &[idx, frag] : llvm::enumerate(frags)) {
1863 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
1866 .resolveOperands(frag.regs, operandTypes[idx], parser.
getNameLoc(),
1876 .resolveOperands(scaleAOperands, scaleTypes, parser.
getNameLoc(),
1886 result.addAttributes(namedAttributes);
1890 result.addTypes(resultTypes);
// The operand-segment-size attribute starts with the register counts of the
// three fragments; the remaining entries are on lines not shown.
1891 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1893 static_cast<int32_t>(frags[0].regs.size()),
1894 static_cast<int32_t>(frags[1].regs.size()),
1895 static_cast<int32_t>(frags[2].regs.size()),
// Builder for MmaBlockScaleOp: populates the operation state with operands,
// the result type and the operand-segment-size attribute.
// NOTE(review): several parameter and call lines are missing from this
// listing; the comments cover only the visible statements.
1906void MmaBlockScaleOp::build(
1911 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
1912 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
1913 MMABlockScaleKind kind) {
// The shape must have exactly three entries (m, n, k).
1914 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
1917 blockScaleFormat, kind);
// Operands are appended in the order A, B, C.
1919 result.addOperands(operandA);
1920 result.addOperands(operandB);
1921 result.addOperands(operandC);
1923 {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
1926 multiplicandPtxTypes);
1928 result.addTypes(resultType);
// The segment sizes start with the sizes of the A, B and C operand lists;
// the remaining entries are on lines not shown.
1929 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1931 static_cast<int32_t>(operandA.size()),
1932 static_cast<int32_t>(operandB.size()),
1933 static_cast<int32_t>(operandC.size()),
1945 auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
1949 for (
Value operand : curOp.getOperandA())
1951 for (
Value operand : curOp.getOperandB())
1953 for (
Value operand : curOp.getOperandC())
1957 args.push_back(mt.
lookupValue(curOp.getScaleAData()));
1958 args.push_back(mt.
lookupValue(curOp.getByteIdA()));
1959 args.push_back(mt.
lookupValue(curOp.getThreadIdA()));
1960 args.push_back(mt.
lookupValue(curOp.getScaleBData()));
1961 args.push_back(mt.
lookupValue(curOp.getByteIdB()));
1962 args.push_back(mt.
lookupValue(curOp.getThreadIdB()));
1964 unsigned intId = MmaBlockScaleOp::getIntrinsicID(
1965 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
1966 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
1968 curOp.getBlockScaleFormat(), curOp.getKind());
1970 return {intId, args};
// Verifier for MmaBlockScaleOp. The visible checks are:
//  - shape 16x8x64: both multiplicand PTX types must be e2m1;
//      kind MXF4     requires ScaleVecSize X2 and BlockScaleFormat UE8M0;
//      kind MXF4NVF4 requires (X2, UE8M0) or (X4 with UE4M3 or UE8M0);
//  - shape 16x8x32: requires kind MXF8F6F4, ScaleVecSize X1 and
//      BlockScaleFormat UE8M0.
// NOTE(review): the lines that define m, n, k and that return or store the
// emitted errors are missing from this listing.
1973LogicalResult MmaBlockScaleOp::verify() {
1979 if (m == 16 && n == 8 && k == 64) {
1980 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
1981 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
1983 "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
1984 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
1985 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
1987 "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
1988 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
1990 "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
1991 }
else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
1992 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
1993 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
1994 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
1995 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
1996 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
1998 "attributes for mma.m16n8k64.mxf4nvf4");
2002 }
else if (m == 16 && n == 8 && k == 32) {
2003 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2004 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2005 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2007 emitOpError(
"unsupported Kind, ScaleVecSize and BlockScaleFormat "
2008 "attributes for mma.m16n8k32");
2021 std::array<MMAOperandFragment, 3> frags{
2022 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
2023 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
2024 MMAOperandFragment(
"C",
"")};
2026 mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
2031 for (
const auto &frag : frags)
2040 {getScaleAData(), getByteIdA(), getThreadIdA()});
2042 {getScaleBData(), getByteIdB(), getThreadIdB()});
2049 frags[1].regs[0].getType(),
2050 frags[2].regs[0].getType()},
// Custom assembly parser for MmaSpBlockScaleOp.
// NOTE(review): most of this definition is missing from the listing; the
// comments cover only the visible statements.
2056ParseResult MmaSpBlockScaleOp::parse(
OpAsmParser &parser,
// Parser-local record for one operand group: its unresolved registers and
// the MMA element type inferred for it.
2058 struct LocalOperandFragment {
2059 std::optional<MMATypes> elemtype;
2060 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
// One fragment per operand group (indices 0, 1, 2).
2064 std::array<LocalOperandFragment, 3> frags;
// For every fragment, infer the element type from the parsed operand type at
// the same index and resolve the fragment's registers against that type.
2100 for (
const auto &[idx, frag] : llvm::enumerate(frags)) {
2101 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
2104 .resolveOperands(frag.regs, operandTypes[idx], parser.
getNameLoc(),
// The metadata operands are resolved against i32Type.
2113 .resolveOperands(metadataOperands, i32Type, parser.
getNameLoc(),
2126 .resolveOperands(scaleAOperands, scaleTypes, parser.
getNameLoc(),
2136 result.addAttributes(namedAttributes);
// Special handling when no "orderedMetadata" attribute was parsed; the
// statement guarded by this condition is on a line not shown.
2141 if (!
result.attributes.get(
"orderedMetadata"))
2144 result.addTypes(resultTypes);
// The operand-segment-size attribute starts with the register counts of the
// three fragments; the remaining entries are on lines not shown.
2145 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2147 static_cast<int32_t>(frags[0].regs.size()),
2148 static_cast<int32_t>(frags[1].regs.size()),
2149 static_cast<int32_t>(frags[2].regs.size()),
// Builder for MmaSpBlockScaleOp: populates the operation state with operands,
// the result type and the operand-segment-size attribute.
// NOTE(review): several parameter and call lines are missing from this
// listing; the comments cover only the visible statements.
2162void MmaSpBlockScaleOp::build(
2168 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
2169 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
2170 MMABlockScaleKind kind) {
// The shape must have exactly three entries (m, n, k).
2171 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
2174 builder,
result,
shape, scaleVecSize, blockScaleFormat, kind);
// Operands are appended in the order A, B, C, then the sparse metadata and
// sparsity selector, then the scale data / byte id / thread id values for A
// followed by those for B.
2177 result.addOperands(operandA);
2178 result.addOperands(operandB);
2179 result.addOperands(operandC);
2180 result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
2181 threadIdA, scaleBData, byteIdB, threadIdB});
2184 multiplicandPtxTypes);
2186 result.addTypes(resultType);
// The segment sizes start with the sizes of the A, B and C operand lists;
// the remaining entries are on lines not shown.
2187 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2189 static_cast<int32_t>(operandA.size()),
2190 static_cast<int32_t>(operandB.size()),
2191 static_cast<int32_t>(operandC.size()),
2205 auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
2209 for (
Value operand : curOp.getOperandA())
2211 for (
Value operand : curOp.getOperandB())
2213 for (
Value operand : curOp.getOperandC())
2217 args.push_back(mt.
lookupValue(curOp.getSparseMetadata()));
2218 args.push_back(mt.
lookupValue(curOp.getSparsitySelector()));
2221 args.push_back(mt.
lookupValue(curOp.getScaleAData()));
2222 args.push_back(mt.
lookupValue(curOp.getByteIdA()));
2223 args.push_back(mt.
lookupValue(curOp.getThreadIdA()));
2224 args.push_back(mt.
lookupValue(curOp.getScaleBData()));
2225 args.push_back(mt.
lookupValue(curOp.getByteIdB()));
2226 args.push_back(mt.
lookupValue(curOp.getThreadIdB()));
2228 unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
2229 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
2230 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
2232 curOp.getBlockScaleFormat(), curOp.getKind());
2234 return {intId, args};
// Verifier for MmaSpBlockScaleOp. The visible checks are:
//  - the orderedMetadata attribute must be present;
//  - shape 16x8x128: both multiplicand PTX types must be e2m1;
//      kind MXF4     requires ScaleVecSize X2 and BlockScaleFormat UE8M0;
//      kind MXF4NVF4 requires (X2, UE8M0) or (X4 with UE4M3 or UE8M0);
//  - shape 16x8x64: requires kind MXF8F6F4, ScaleVecSize X1 and
//      BlockScaleFormat UE8M0.
// NOTE(review): the lines that define m, n, k and that return or store the
// emitted errors are missing from this listing.
2237LogicalResult MmaSpBlockScaleOp::verify() {
2239 if (!getOrderedMetadata()) {
2240 return emitOpError(
"'orderedMetadata' attribute is mandatory");
2248 if (m == 16 && n == 8 && k == 128) {
2249 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2250 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2252 "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
2253 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2254 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2256 "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
2257 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2259 "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
2260 }
else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2261 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2262 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2263 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2264 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
2265 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
2267 "attributes for mma.m16n8k128.mxf4nvf4");
2271 }
else if (m == 16 && n == 8 && k == 64) {
2272 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2273 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2274 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2276 emitOpError(
"unsupported Kind, ScaleVecSize and BlockScaleFormat "
2277 "attributes for mma.m16n8k64");
// Verifier for ShflOp: checks that the result type agrees with the
// return_value_and_is_valid attribute and with the type of the `val` operand.
// NOTE(review): a few lines (for example the definition of returnStruct and
// the else/closing lines) are missing from this listing.
2284LogicalResult ShflOp::verify() {
// Null when the result type is not an LLVM struct type.
2285 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(
getType());
// Emits "expected <desc> to be of type <expectedType> but got <actualType>
// instead" as an op error.
2287 auto verifyTypeError = [&](Twine desc,
Type expectedType,
2288 Type actualType) -> LogicalResult {
2289 return emitOpError(
"expected " + desc +
" to be of type ")
2290 << expectedType <<
" but got " << actualType <<
" instead";
// Struct result: the attribute must be set, the struct must have exactly two
// elements, the first must have the type of `val` and the second must be i1.
2293 if (returnStructType) {
2294 if (!getReturnValueAndIsValid())
2295 return emitOpError(
"\"return_value_and_is_valid\" attribute must be "
2296 "specified when the return type is a struct type");
2298 if (returnStructType.getBody().size() != 2)
2299 return emitOpError(
"expected return type to be a two-element struct");
2302 auto resultType = returnStruct[0];
2303 if (resultType != getVal().
getType())
2304 return verifyTypeError(
"first element in the returned struct",
2305 getVal().
getType(), resultType);
2307 auto predicateType = returnStruct[1];
2308 if (!predicateType.isInteger(1))
2309 return verifyTypeError(
"second element in the returned struct",
// Setting the attribute is an error when reached here (presumably the
// non-struct branch; its opening line is not shown).
2313 if (getReturnValueAndIsValid())
2314 return emitOpError(
"expected return type to be a two-element struct");
2317 return verifyTypeError(
"return type", getVal().
getType(),
getType());
// Infers the result type of ShflOp from the type of its `val` operand: a
// literal struct {val type, i1} when return_value_and_is_valid is set,
// otherwise (on the branch whose `else` line is not shown) the val type.
// NOTE(review): the return type and part of the signature are on lines
// missing from this listing.
2323ShflOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
2324 ShflOp::Adaptor adaptor,
2326 Type valType = adaptor.getVal().getType();
2327 if (adaptor.getReturnValueAndIsValid())
2328 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
2329 context, {valType, IntegerType::get(context, 1)}));
2331 inferredReturnTypes.push_back(valType);
2336 NVVM::MMAFrag frag,
int nRow,
2339 unsigned numberElements = 0;
2342 Type f16x2 = VectorType::get(2, builder.getF16Type());
2343 if (type == NVVM::MMATypes::f16) {
2344 elementType = f16x2;
2345 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2349 }
else if (type == NVVM::MMATypes::f32) {
2350 elementType = builder.getF32Type();
2352 }
else if (type == NVVM::MMATypes::f64) {
2353 elementType = builder.getF64Type();
2354 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2358 }
else if (type == NVVM::MMATypes::tf32) {
2359 elementType = builder.getI32Type();
2361 }
else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
2362 elementType = builder.getI32Type();
2363 int parallelSize = 0;
2364 if (frag == NVVM::MMAFrag::a)
2365 parallelSize = nRow;
2366 if (frag == NVVM::MMAFrag::b)
2367 parallelSize = nCol;
2370 if (parallelSize == 16)
2373 else if (parallelSize == 8)
2375 else if (parallelSize == 32)
2377 }
else if (type == NVVM::MMATypes::s32) {
2378 elementType = builder.getI32Type();
2381 assert(numberElements != 0 && elementType !=
nullptr);
2382 return std::make_pair(elementType, numberElements);
2385static std::pair<mlir::Type, unsigned>
2389 if (frag == NVVM::MMAFrag::a) {
2392 }
else if (frag == NVVM::MMAFrag::b) {
2399 assert(nRow && nCol);
// Verifier for WMMALoadOp: checks the pointer's address space, that the
// attribute combination maps to an intrinsic, and the result type.
// NOTE(review): the definitions of typeInfo and f64Ty and several condition
// lines are missing from this listing.
2403LogicalResult NVVM::WMMALoadOp::verify() {
// The pointer must be in address space 0, Global or Shared.
2404 unsigned addressSpace =
2405 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
2406 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2407 addressSpace != NVVMMemorySpace::Shared)
2408 return emitOpError(
"expected source pointer in memory "
// An intrinsic ID of 0 is treated as an invalid attribute combination.
2411 if (NVVM::WMMALoadOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
2412 getEltype(), getFrag()) == 0)
2413 return emitOpError() <<
"invalid attribute combination";
// A single f64 element is handled separately from the struct case below.
2418 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
2420 return emitOpError(
"expected destination type to be f64");
// Otherwise the expected result is a literal struct built from typeInfo.
2424 Type dstType = LLVM::LLVMStructType::getLiteral(
2427 return emitOpError(
"expected destination type is a structure of ")
2428 << typeInfo.second <<
" elements of type " << typeInfo.first;
// Verifier for WMMAStoreOp: checks the pointer's address space, that the
// attribute combination maps to an intrinsic, and the number and type of the
// data operands.
// NOTE(review): the definition of typeInfo and a few continuation lines are
// missing from this listing.
2432LogicalResult NVVM::WMMAStoreOp::verify() {
// The pointer must be in address space 0, Global or Shared.
2433 unsigned addressSpace =
2434 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
2435 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2436 addressSpace != NVVMMemorySpace::Shared)
2437 return emitOpError(
"expected operands to be a source pointer in memory "
2440 if (NVVM::WMMAStoreOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
2442 return emitOpError() <<
"invalid attribute combination";
// The number of data operands must equal typeInfo.second and every data
// operand must have type typeInfo.first.
2445 if (getArgs().size() != typeInfo.second)
2446 return emitOpError() <<
"expected " << typeInfo.second <<
" data operands";
2447 if (llvm::any_of(getArgs(), [&typeInfo](
Value operands) {
2448 return operands.
getType() != typeInfo.first;
2450 return emitOpError() <<
"expected data operands of type " << typeInfo.first;
// Verifier for WMMAMmaOp: checks that the attribute combination maps to an
// intrinsic, that the arguments match the expected A, B and C fragment
// types, and that the result is the expected struct.
// NOTE(review): the definitions of typeInfoA/B/C and `arguments` and a few
// continuation lines are missing from this listing.
2454LogicalResult NVVM::WMMAMmaOp::verify() {
2455 if (NVVM::WMMAMmaOp::getIntrinsicID(
getM(),
getN(), getK(), getLayoutA(),
2456 getLayoutB(), getEltypeA(),
2458 return emitOpError() <<
"invalid attribute combination";
// Expected argument types: the A fragment type repeated by its count, then
// the same for B and for C.
2466 arguments.append(typeInfoA.second, typeInfoA.first);
2467 arguments.append(typeInfoB.second, typeInfoB.first);
2468 arguments.append(typeInfoC.second, typeInfoC.first);
2469 unsigned numArgs = arguments.size();
2470 if (getArgs().size() != numArgs)
2471 return emitOpError() <<
"expected " << numArgs <<
" arguments";
// Each argument must have the expected type at its position.
2472 for (
unsigned i = 0; i < numArgs; i++) {
2473 if (getArgs()[i].
getType() != arguments[i])
2474 return emitOpError() <<
"expected argument " << i <<
" to be of type "
2477 Type dstType = LLVM::LLVMStructType::getLiteral(
2480 return emitOpError(
"expected destination type is a structure of ")
2481 << typeInfoC.second <<
" elements of type " << typeInfoC.first;
// Verifier for LdMatrixOp. The visible per-shape rules are:
//  - 8x8:   num in {1, 2, 4}; element type B16;
//  - 8x16:  num in {1, 2, 4}; row layout; element type B8X16_B4X16_P64 or
//           B8X16_B6X16_P32;
//  - 16x16: num in {1, 2}; col layout; element type B8, B8X16_B4X16_P64 or
//           B8X16_B6X16_P32;
//  - any other shape is rejected.
// The result must be i32 for one element, or a struct for two or four.
// NOTE(review): the definitions of m, n, num and i32 and some closing lines
// are missing from this listing.
2485LogicalResult NVVM::LdMatrixOp::verify() {
2487 if (m == 8 && n == 8) {
2488 if (num != 1 && num != 2 && num != 4) {
2489 return emitOpError(
"expected num attribute to be 1, 2 or 4 for 8x8 "
2492 if (getEltType() != LdStMatrixEltType::B16) {
2493 return emitOpError(
"expected element type to be b16 for 8x8 matrix");
2495 }
else if (m == 8 && n == 16) {
2496 if (num != 1 && num != 2 && num != 4) {
2497 return emitOpError(
"expected num attribute to be 1, 2 or 4 for 8x16 "
2500 if (getLayout() != MMALayout::row) {
2501 return emitOpError(
"expected layout to be row for 8x16 matrix");
2503 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2504 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2505 return emitOpError(
"expected element type to be b8x16.b4x16_p64 or "
2506 "b8x16.b6x16_p32 for 8x16 matrix");
2508 }
else if (m == 16 && n == 16) {
2509 if (num != 1 && num != 2) {
2510 return emitOpError(
"expected num attribute to be 1 or 2 for 16x16 "
2513 if (getLayout() != MMALayout::col) {
2514 return emitOpError(
"expected layout to be col for 16x16 matrix");
2516 if (getEltType() != LdStMatrixEltType::B8 &&
2517 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2518 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2519 return emitOpError(
"expected element type to be b8, b8x16.b4x16_p64 or "
2520 "b8x16.b6x16_p32 for 16x16 matrix");
2523 return emitOpError(
"expected shape to be 8x8, 8x16 or 16x16");
// A 16x16 shape yields twice as many result elements as `num`.
2527 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
2528 if (numElements == 1 &&
getType() != i32)
2529 return emitOpError(
"expected destination type is i32");
2530 if (numElements == 2 || numElements == 4) {
2531 Type dstType = LLVM::LLVMStructType::getLiteral(
2534 return emitOpError(
"expected destination type is a structure of ")
2535 << numElements <<
" elements of type i32";
// Infers the result type of LdMatrixOp from its num and shape attributes:
// i32 when a single element is produced, otherwise a literal struct (its
// element list is on a line not shown in this listing).
2541LogicalResult LdMatrixOp::inferReturnTypes(
2542 MLIRContext *context, std::optional<Location> location,
2544 uint32_t num = adaptor.getNum();
2545 uint32_t m = adaptor.getShape().getM();
2546 uint32_t n = adaptor.getShape().getN();
// A 16x16 shape yields twice as many result elements as `num`.
2547 uint32_t numElements = (m == 16 && n == 16) ? num * 2 : num;
2549 Type i32 = IntegerType::get(context, 32);
2550 if (numElements == 1)
2551 inferredReturnTypes.push_back(i32);
2553 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
// Verifier for StMatrixOp. The visible rules are:
//  - the number of source operands must be 1, 2 or 4;
//  - 8x8:  element type B16;
//  - 16x8: element type B8 and col layout;
//  - any other shape is rejected.
// NOTE(review): the definitions of m and n and some closing lines are
// missing from this listing.
2558LogicalResult NVVM::StMatrixOp::verify() {
2559 int numMatrix = getSources().size();
2560 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
2561 return emitOpError(
"expected num attribute to be 1, 2 or 4");
2564 if (m == 8 && n == 8) {
2565 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
2566 return emitOpError(
"expected element type to be B16 for 8x8 matrix");
2568 }
else if (m == 16 && n == 8) {
2569 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
2570 return emitOpError(
"expected element type to be B8 for 16x8 matrix");
2572 if (getLayout() != NVVM::MMALayout::col) {
2573 return emitOpError(
"expected layout to be col for 16x8 matrix");
2576 return emitOpError(
"expected shape to be 8x8 or 16x8");
// Verifier for MovMatrixOp: tests for a shape other than 8x8, a layout other
// than col, and an element type other than B16.
// NOTE(review): the definitions of m and n and the statements guarded by the
// first two conditions are on lines missing from this listing; only the
// element-type error is visible.
2582LogicalResult NVVM::MovMatrixOp::verify() {
2584 if (m != 8 || n != 8)
2586 if (getLayout() != NVVM::MMALayout::col)
2588 if (getEltType() != NVVM::LdStMatrixEltType::B16)
2589 return emitOpError(
"expected element type to be b16");
2594 if (typeA == NVVM::WGMMATypes::tf32)
2596 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
2598 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
2600 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
2602 if (typeA == NVVM::WGMMATypes::b1)
2608 NVVM::WGMMATypes typeA,
2609 NVVM::WGMMATypes typeB) {
2611 case NVVM::WGMMATypes::f16:
2612 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2613 typeB == NVVM::WGMMATypes::f16)
2616 case NVVM::WGMMATypes::tf32:
2617 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
2620 case NVVM::WGMMATypes::u8:
2621 case NVVM::WGMMATypes::s8:
2622 if (typeD == NVVM::WGMMATypes::s32 &&
2623 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
2626 case NVVM::WGMMATypes::b1:
2627 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
2630 case NVVM::WGMMATypes::bf16:
2631 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2632 typeB == NVVM::WGMMATypes::bf16)
2635 case NVVM::WGMMATypes::e4m3:
2636 case NVVM::WGMMATypes::e5m2:
2637 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2638 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
2641 case WGMMATypes::f32:
2642 case WGMMATypes::s32:
2643 llvm_unreachable(
"unsupported input types");
2651 72, 80, 88, 96, 104, 112, 120, 128,
2652 136, 144, 152, 160, 168, 176, 184, 192,
2653 200, 208, 216, 224, 232, 240, 248, 256};
2655 80, 96, 112, 128, 144, 160,
2656 176, 192, 208, 224, 240, 256};
2658 case WGMMATypes::f16:
2659 case WGMMATypes::tf32:
2660 case WGMMATypes::bf16:
2661 case WGMMATypes::e4m3:
2662 case WGMMATypes::e5m2:
2663 if (llvm::is_contained(allowedN, sizeN))
2666 case WGMMATypes::u8:
2667 case WGMMATypes::s8:
2668 case WGMMATypes::b1:
2669 if (llvm::is_contained(allowedNshort, sizeN))
2672 case WGMMATypes::f32:
2673 case WGMMATypes::s32:
2674 llvm_unreachable(
"unsupported input types");
// Verifier for WgmmaMmaAsyncOp: checks the result struct, the accumulator
// type, the scale attributes, the input type combination, the k and n sizes,
// the layouts and the use of satfinite.
// NOTE(review): several condition lines (for example the ones guarding the
// type-combination, k and n diagnostics) are missing from this listing.
2680LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
// The result must be an LLVM struct; outputSize is its element count.
2681 Value outValue = getResults();
2682 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.
getType());
2684 return emitOpError() <<
"expected results to be struct";
2685 int outputSize = stype.getBody().size();
2686 WGMMATypes typeD = getTypeD();
2687 WGMMATypes typeA = getTypeA();
2688 WGMMATypes typeB = getTypeB();
// Every struct element must have the same type as the first element.
2690 for (
Type t : stype.getBody()) {
2691 if (t != stype.getBody().front())
2693 <<
"all elements in struct must be same type but there is " << t;
// Only f32, f16 and s32 are accepted as the output type.
2696 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
2697 typeD != WGMMATypes::s32) {
2698 return emitOpError() <<
"does not support the given output type " << typeD;
// With an s32 output neither scaleA nor scaleB may be neg.
2700 if (typeD == WGMMATypes::s32 &&
2701 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
2702 return emitOpError() <<
"has s32 output, scaleA and scaleB cannot be neg";
2706 return emitOpError() << typeD <<
" += " << typeA <<
" * " << typeB
2707 <<
", it is not supported.";
2717 return emitOpError() <<
"shape 'k' must be " << allowedK.value()
2718 <<
" for input type " << typeA;
2722 return emitOpError() <<
"has input type " << typeA <<
" n is set to "
2723 <<
getShape().getN() <<
", it is not supported.";
// A col layout for A or a row layout for B is rejected unless the A input
// type is f16 or bf16.
2730 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
2731 (getLayoutA() == mlir::NVVM::MMALayout::col ||
2732 getLayoutB() == mlir::NVVM::MMALayout::row)) {
2734 <<
"given layouts layout_a = " << getLayoutA()
2735 <<
" and layout_b = " << getLayoutB() <<
" for input types " << typeA
2737 <<
" requires transpose. However, this is only supported for: "
2738 << MMATypes::f16 <<
" and " << MMATypes::bf16;
// The result struct must hold N / 2 elements for f32 or s32 outputs and
// N / 4 elements for f16 outputs.
2742 int expectedOutput = 0;
2743 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
2744 expectedOutput =
getShape().getN() / 2;
2745 if (typeD == WGMMATypes::f16)
2746 expectedOutput =
getShape().getN() / 4;
2747 if (outputSize != expectedOutput) {
2748 return emitOpError() <<
"results " << expectedOutput
2749 <<
", however output struct has " << outputSize
// satfinite is rejected unless the output type is s32; an absent satfinite
// attribute is treated as wrapped.
2753 if (typeD != WGMMATypes::s32 &&
2754 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2755 NVVM::MMAIntOverflow::satfinite) {
2757 <<
" `satfinite` can be only used with s32 accumulator, however "
2758 "the current accumulator is "
// Builds the PTX text for this op by streaming it into a string.
// NOTE(review): many lines are missing from this listing (for example the
// definitions of m, n, k, ptx and regCnt, and parts of the streamed text);
// the comments cover only the visible statements.
2765std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
2768 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2770 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
// Output register count: N / 4 for an f16 output type; the N / 2 assignment
// below is presumably the else branch (its `else` line is not shown).
2772 int expectedOutputRegisters = 0;
2773 if (getTypeD() == WGMMATypes::f16)
2774 expectedOutputRegisters =
getShape().getN() / 4;
2776 expectedOutputRegisters =
getShape().getN() / 2;
2779 llvm::raw_string_ostream ss(ptx);
2784 << ((expectedOutputRegisters * 2) + 2)
2786 "wgmma.mma_async.sync.aligned.m"
2787 << m <<
"n" << n <<
"k" << k <<
"." << outputTypeName <<
"." << getTypeA()
2788 <<
"." << getTypeB();
2789 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2790 NVVM::MMAIntOverflow::satfinite)
// Emit one "$<index>" placeholder per output register.
2794 for (; regCnt < expectedOutputRegisters; ++regCnt) {
2795 ss <<
"$" << regCnt;
2796 if (regCnt != expectedOutputRegisters - 1)
// The placeholders that follow are numbered starting from twice the current
// register count.
2802 regCnt = (regCnt * 2);
2803 ss <<
" $" << (regCnt) <<
"," <<
" $" << (regCnt + 1) <<
"," <<
" p";
// Two further placeholders are emitted only when the output type is not s32.
2804 if (getTypeD() != WGMMATypes::s32) {
2805 ss <<
", $" << (regCnt + 3) <<
", $" << (regCnt + 4);
2809 ss <<
", $" << (regCnt + 5) <<
", $" << (regCnt + 6);
/// Appends the inline-asm operands of this op to `asmValues`.
/// NOTE(review): the parameter list and several statements are not visible in
/// this view; only the shown push_back calls are described.
2816bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
// True when the A operand type is f16 or bf16.
2820 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
// scale-d is passed as an i32 constant holding the enum's integer value.
2827 asmValues.push_back({makeConstantI32(rewriter,
static_cast<int>(getScaleD())),
// For non-s32 accumulators, scale-a and scale-b are encoded as -1 when the
// attribute is `neg` and as 1 otherwise.
2829 if (getTypeD() != WGMMATypes::s32) {
2830 asmValues.push_back(
2831 {makeConstantI32(rewriter,
2832 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2834 asmValues.push_back(
2835 {makeConstantI32(rewriter,
2836 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
// Layout A is passed as its enum integer value; layout B as one minus its
// enum integer value. The enclosing condition is not visible here.
2840 asmValues.push_back(
2841 {makeConstantI32(rewriter,
static_cast<int>(getLayoutA())),
2843 asmValues.push_back(
2844 {makeConstantI32(rewriter, 1 -
static_cast<int>(getLayoutB())),
2850LogicalResult NVVM::FenceProxyOp::verify() {
2851 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2852 return emitOpError() <<
"async_shared fence requires space attribute";
2854 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2855 return emitOpError() <<
"only async_shared fence can have space attribute";
2860LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2861 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2862 return emitOpError(
"uni-directional proxies only support generic for "
2863 "from_proxy attribute");
2865 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2866 return emitOpError(
"uni-directional proxies only support tensormap "
2867 "for to_proxy attribute");
2871LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2872 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2873 return emitOpError(
"uni-directional proxies only support generic for "
2874 "from_proxy attribute");
2876 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2877 return emitOpError(
"uni-directional proxies only support tensormap "
2878 "for to_proxy attribute");
2882LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2883 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2884 return emitOpError(
"only generic is support for from_proxy attribute");
2886 if (getToProxy() != NVVM::ProxyKind::async)
2887 return emitOpError(
"only async is supported for to_proxy attribute");
2891LogicalResult NVVM::SetMaxRegisterOp::verify() {
2892 if (getRegCount() % 8)
2893 return emitOpError(
"new register size must be multiple of 8");
2894 if (getRegCount() < 24 || getRegCount() > 256)
2895 return emitOpError(
"new register size must be in between 24 to 256");
/// Verifies that the multicast attribute is compatible with the copy shape.
/// NOTE(review): the switch header and some guards/breaks are not visible in
/// this view; the comments describe only the visible cases.
2899LogicalResult NVVM::Tcgen05CpOp::verify() {
2900 auto mc = getMulticast();
// Short aliases for the shape and multicast enums used in the cases below.
2902 using SH = Tcgen05CpShape;
2903 using MC = Tcgen05CpMulticast;
// These three shapes share one check; the condition guarding the error is
// not visible here.
2905 case SH::SHAPE_128x256b:
2906 case SH::SHAPE_128x128b:
2907 case SH::SHAPE_4x256b:
2909 return emitError(
"Invalid multicast type for tcgen05.cp Op");
// 64x128b accepts only the two warpx2 multicast variants.
2911 case SH::SHAPE_64x128b:
2912 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
2913 return emitError(
"Shape 64x128b requires multicast warpx2_01_23 or "
2914 "warpx2_02_13 for tcgen05.cp Op");
// 32x128b accepts only warpx4 multicast.
2916 case SH::SHAPE_32x128b:
2917 if (mc != MC::WARPX4)
2919 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
/// Verifies the result type of match.sync against its kind.
/// NOTE(review): closing braces and the branch separating the two checks are
/// not visible in this view.
2925LogicalResult NVVM::MatchSyncOp::verify() {
// Kind `all`: the result must be a struct with exactly two members, an i32
// followed by an i1.
2926 if (getKind() == NVVM::MatchSyncKind::all) {
2927 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(
getType());
2928 if (!type || type.getBody().size() != 2 ||
2929 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2930 return emitOpError(
"match.sync 'all' returns a two element struct with "
2931 "first element as i32 and second element as i1");
// Per the message below this check is for kind `any`: the result must be i32.
2934 if (!
getType().isInteger(32)) {
2935 return emitOpError(
"match.sync 'any' returns an i32");
/// Infers the result type of match.sync from its kind.
/// NOTE(review): part of the signature and the branch between the two
/// push_back calls are not visible in this view.
2941LogicalResult MatchSyncOp::inferReturnTypes(
2942 MLIRContext *context, std::optional<Location> location,
// Kind `all`: a literal LLVM struct type built from {i32, i1}.
2944 if (adaptor.getKind() == NVVM::MatchSyncKind::all)
2945 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
2947 {IntegerType::get(context, 32), IntegerType::get(context, 1)}));
// A plain i32 result (presumably the non-`all` branch -- TODO confirm).
2949 inferredReturnTypes.push_back(IntegerType::get(context, 32));
/// Verifies the result type of vote.sync against its kind.
/// NOTE(review): closing braces and the branch separating the two checks are
/// not visible in this view.
2953LogicalResult NVVM::VoteSyncOp::verify() {
// Kind `ballot`: the result must be i32.
2954 if (getKind() == NVVM::VoteSyncKind::ballot) {
2955 if (!
getType().isInteger(32)) {
2956 return emitOpError(
"vote.sync 'ballot' returns an i32");
// Per the message below this check covers `any`, `all` and `uni`: the result
// must be i1.
2959 if (!
getType().isInteger(1)) {
2960 return emitOpError(
"vote.sync 'any', 'all' and 'uni' returns an i1");
/// Infers the result type of vote.sync: an integer whose width depends on the
/// kind.
/// NOTE(review): part of the signature and the final return are not visible
/// in this view.
2966LogicalResult VoteSyncOp::inferReturnTypes(
2967 MLIRContext *context, std::optional<Location> location,
// 32 bits for `ballot`, 1 bit for every other kind.
2969 unsigned width = adaptor.getKind() == NVVM::VoteSyncKind::ballot ? 32 : 1;
2970 inferredReturnTypes.push_back(IntegerType::get(context, width));
/// Verifies the attribute/pointer combinations accepted by the prefetch op.
/// The visible checks cover tensormap prefetches, cache-level prefetches, and
/// the case where neither is specified.
/// NOTE(review): several guards and `return emitOpError(` lines are not
/// visible in this view; comments describe only what is shown.
2974LogicalResult NVVM::PrefetchOp::verify() {
2975 using MemSpace = NVVM::NVVMMemorySpace;
2976 using CacheLevel = NVVM::PrefetchCacheLevel;
// Address space of the pointer operand.
2978 unsigned addressSpace =
2979 llvm::cast<LLVM::LLVMPointerType>(getAddr().
getType()).getAddressSpace();
2980 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
2981 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
// Tensormap and cache level are mutually exclusive.
2983 if (getTensormap() && cacheLevel)
2984 return emitOpError(
"cannot specify both tensormap and cache level");
// Tensormap prefetch: the pointer must be generic or constant, no eviction
// priority is accepted, and in_param_space requires a generic pointer.
2986 if (getTensormap()) {
2987 if (addressSpace != MemSpace::Generic &&
2988 addressSpace != MemSpace::Constant) {
2990 "prefetch tensormap requires a generic or constant pointer");
2993 if (evictPriority) {
2995 "prefetch tensormap does not support eviction priority");
2998 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
3000 "in_param_space can only be specified for a generic pointer");
3003 }
else if (cacheLevel) {
// Cache-level prefetch: the pointer must be generic, global, or local.
3004 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
3005 addressSpace != MemSpace::Local) {
3006 return emitOpError(
"prefetch to cache level requires a generic, global, "
3007 "or local pointer");
// Per the messages below these two checks apply to the uniform cache: only
// L1 is accepted and the pointer must be generic. The enclosing condition is
// not visible here.
3011 if (*cacheLevel != CacheLevel::L1) {
3013 "unsupported cache level, the only supported uniform "
3014 "cache level is L1");
3017 if (addressSpace != MemSpace::Generic) {
3019 "prefetch to uniform cache requires a generic pointer");
// Eviction priority: requires cache level L2, a global pointer, and a
// priority of evict_normal or evict_last.
3023 if (evictPriority) {
3024 if (*cacheLevel != CacheLevel::L2)
3026 "cache eviction priority supported only for cache level L2");
3028 if (addressSpace != MemSpace::Global)
3029 return emitOpError(
"cache eviction priority requires a global pointer");
3031 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
3032 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
3034 "unsupported cache eviction priority, only evict_last and "
3035 "evict_normal are supported");
// The conditions guarding the two diagnostics below are not visible here.
3039 return emitOpError(
"predicate supported only on prefetch tensormap");
3043 "requires specification of either cache level or tensormap");
/// Verifies the result type of the query-cancel op against its query type.
/// NOTE(review): the type check for IS_CANCELED and the closing lines are not
/// visible in this view.
3049LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
3050 switch (getQueryType()) {
// IS_CANCELED: the diagnostic states that the result is an i1.
3051 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
3053 return emitOpError(
"is_canceled query type returns an i1");
// The three first-CTA-id queries share one check: the result must be i32.
3055 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
3056 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
3057 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
3058 if (!
getType().isInteger(32)) {
3059 return emitOpError(
"get_first_cta_id_x, get_first_cta_id_y, "
3060 "get_first_cta_id_z query types return an i32");
/// Infers the result type of the query-cancel op as an integer type whose
/// width is chosen from the query type.
/// NOTE(review): the remainder of the signature and the expression that
/// computes `width` are only partly visible in this view.
3067LogicalResult ClusterLaunchControlQueryCancelOp::inferReturnTypes(
3068 MLIRContext *context, std::optional<Location> location,
3069 ClusterLaunchControlQueryCancelOp::Adaptor adaptor,
// The width selection compares the query type against IS_CANCELED.
3072 adaptor.getQueryType() == NVVM::ClusterLaunchControlQueryType::IS_CANCELED
3075 inferredReturnTypes.push_back(IntegerType::get(context, width));
/// Verifies that the redux op's attributes and reduction kind agree with the
/// type being reduced.
/// NOTE(review): the definition of `reduxType`, the switch header, and
/// several guards are not visible in this view.
3079LogicalResult NVVM::ReduxOp::verify() {
// For non-f32 types the `abs` and `nan` attributes are rejected; the
// individual attribute checks are not visible here.
3082 if (!reduxType.
isF32()) {
3084 return emitOpError(
"abs attribute is supported only for f32 type");
3086 return emitOpError(
"nan attribute is supported only for f32 type");
3089 NVVM::ReductionKind kind = getKind();
// These kinds share one diagnostic, which names i32 as the only supported
// type; the type check itself is not visible here.
3091 case NVVM::ReductionKind::ADD:
3092 case NVVM::ReductionKind::AND:
3093 case NVVM::ReductionKind::OR:
3094 case NVVM::ReductionKind::XOR:
3095 case NVVM::ReductionKind::MAX:
3096 case NVVM::ReductionKind::MIN:
3097 case NVVM::ReductionKind::UMAX:
3098 case NVVM::ReductionKind::UMIN:
3101 << kind <<
"' reduction kind unsupported with " << reduxType
3102 <<
" type. Only supported type is 'i32'.";
// FMIN/FMAX require an f32 type.
3104 case NVVM::ReductionKind::FMIN:
3105 case NVVM::ReductionKind::FMAX:
3106 if (!reduxType.isF32())
3108 << kind <<
"' reduction kind unsupported with " << reduxType
3109 <<
" type. Only supported type is 'f32'.";
/// Verifies that the ordinal, `new_value` operand and `new_value_attr`
/// attribute are consistent with the tensormap field being replaced.
/// NOTE(review): many guards and return statements are not visible in this
/// view; the comments describe only what is shown.
3116LogicalResult NVVM::TensormapReplaceOp::verify() {
3117 auto ord = getOrd();
3118 Value newVal = getNewValue();
3119 auto newValAttr = getNewValueAttr();
// Field name as text, used in the diagnostics below.
3120 auto fieldName = stringifyEnum(getField());
// An ordinal is accepted only for the four fields listed here.
3122 if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
3123 NVVM::TensormapField::GLOBAL_DIM,
3124 NVVM::TensormapField::GLOBAL_STRIDE,
3125 NVVM::TensormapField::ELEMENT_STRIDE},
3127 return emitOpError(
"ordinal is not supported for ")
3128 << fieldName <<
" field";
// Builds the "new_value must be specified ..." message for a given type name.
3130 auto invalidNewVal = [&](llvm::Twine type) -> std::string {
3131 return llvm::Twine(
"new_value must be specified and must be an " + type +
3132 " for " + llvm::Twine(fieldName) +
" field")
// Builds the "new_value_attr must be specified ..." message for this field.
3136 auto invalidNewValAttr = [&]() -> std::string {
3137 return (llvm::Twine(
3138 "new_value_attr must be specified and must be a valid ") +
3139 llvm::Twine(fieldName) +
" attribute for " + fieldName +
" field")
// Per-field checks; the bodies of several cases are not visible here.
3143 switch (getField()) {
3144 case NVVM::TensormapField::GLOBAL_ADDRESS:
3148 case NVVM::TensormapField::RANK:
3152 case NVVM::TensormapField::GLOBAL_STRIDE:
3154 return emitOpError(
"ordinal is required for global_stride field");
3158 case NVVM::TensormapField::BOX_DIM:
3159 case NVVM::TensormapField::GLOBAL_DIM:
3160 case NVVM::TensormapField::ELEMENT_STRIDE:
3163 << stringifyEnum(getField()) <<
" field";
// The remaining fields each require `new_value_attr` to be present and of the
// attribute class matching the field.
3167 case NVVM::TensormapField::ELEMTYPE:
3168 if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
3171 case NVVM::TensormapField::INTERLEAVE_LAYOUT:
3172 if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
3175 case NVVM::TensormapField::SWIZZLE_MODE:
3176 if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
3179 case NVVM::TensormapField::SWIZZLE_ATOMICITY:
3180 if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
3183 case NVVM::TensormapField::FILL_MODE:
3184 if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
3192template <
typename OpType>
3194 mlir::NVVM::FPRoundingMode rndMode = op.getRnd();
3195 mlir::NVVM::SaturationMode satMode = op.getSat();
3196 bool isFTZ = op.getFtz();
3199 mlir::Type opBaseType = isa<VectorType>(opType)
3200 ? cast<VectorType>(opType).getElementType()
3203 if (opBaseType.
isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3204 return op.emitOpError(
"FTZ and saturation are not supported for "
3205 "additions/subtractions involving f64 type");
3207 if (opBaseType.
isF16() && !(rndMode == NVVM::FPRoundingMode::RN ||
3208 rndMode == NVVM::FPRoundingMode::NONE))
3209 return op.emitOpError(
"only RN rounding mode is supported for f16 and "
3210 "vector<2xf16> additions/subtractions");
3212 if (opBaseType.
isBF16()) {
3213 if (rndMode != NVVM::FPRoundingMode::RN &&
3214 rndMode != NVVM::FPRoundingMode::NONE)
3215 return op.emitOpError(
"only RN rounding mode is supported for bf16 and "
3216 "vector<2xbf16> additions/subtractions");
3217 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3218 return op.emitOpError(
"FTZ and saturation are not supported for bf16 and "
3219 "vector<2xbf16> additions/subtractions");
3226 if (opBaseType.
isF16() && isFTZ && satMode == NVVM::SaturationMode::NONE)
3227 return op.emitOpError(
"FTZ with no saturation is not supported for f16 and "
3228 "vector<2xf16> additions/subtractions");
/// Verifies the rounding, saturation, FTZ, relu and oob modifiers of the fma
/// op against each other and against the element type of the result.
/// NOTE(review): a few `return emitOpError(` lines and the tail of the
/// `getBaseFType` lambda are not visible in this view.
3237LogicalResult NVVM::FmaOp::verify() {
3238 auto opType = getRes().getType();
3239 mlir::NVVM::FPRoundingMode rndMode = getRnd();
3240 mlir::NVVM::SaturationMode satMode = getSat();
3241 bool isFTZ = getFtz();
3242 bool isRelu = getRelu();
3243 bool hasOOB = getOob();
// For vector types the checks below look at the element type.
3245 auto getBaseFType = [](
Type type) ->
Type {
3246 if (isa<VectorType>(type))
3247 return cast<VectorType>(type).getElementType();
3251 auto opBaseType = getBaseFType(opType);
// A rounding mode is mandatory.
3253 if (rndMode == NVVM::FPRoundingMode::NONE)
3254 return emitOpError(
"rounding mode must be specified");
// Modifier combinations that are rejected regardless of type.
3256 if (isRelu && satMode == NVVM::SaturationMode::SAT)
3257 return emitOpError(
"relu and saturation are not supported together");
3259 if (hasOOB && (satMode == NVVM::SaturationMode::SAT || isFTZ))
3260 return emitOpError(
"oob is not supported with saturation or FTZ");
// relu/oob are limited to f16 and bf16 element types.
3262 if (!(opBaseType.isF16() || opBaseType.isBF16()) && (isRelu || hasOOB))
3263 return emitOpError(
"relu and oob are only supported for f16 and bf16");
// f64: neither saturation nor FTZ.
3265 if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3266 return emitOpError(
"FTZ and saturation are not supported for f64 type");
// f16: RN rounding only.
3268 if (opBaseType.isF16() && rndMode != NVVM::FPRoundingMode::RN)
3270 "only RN rounding mode is supported for f16 and vector<2xf16>");
// bf16: RN rounding only, and neither saturation nor FTZ.
3272 if (opBaseType.isBF16()) {
3273 if (rndMode != NVVM::FPRoundingMode::RN)
3275 "only RN rounding mode is supported for bf16 and vector<2xbf16>");
3276 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3278 "FTZ and saturation are not supported for bf16 and vector<2xbf16>");
3284LogicalResult NVVM::SqrtOp::verify() {
3285 if (getRnd() == NVVM::FPRoundingMode::NONE)
3286 return emitOpError(
"rounding mode cannot be None");
3288 if (getRes().
getType().isF64() && getFtz())
3289 return emitOpError(
"FTZ is not supported for f64");
/// Verifies the modifiers of the floating-point divide op: the `approx` and
/// `full` forms versus the rounded form.
/// NOTE(review): two guard conditions and the return/closing lines are not
/// visible in this view.
3294LogicalResult NVVM::DivFOp::verify() {
3295 bool isApprox = getApprox();
3296 bool isFull = getFull();
3297 bool isF64 = getRes().getType().isF64();
3298 bool isFtz = getFtz();
3299 NVVM::FPRoundingMode rndMode = getRnd();
// At most one of `approx` / `full` may be set.
3301 if (isApprox && isFull)
3302 return emitOpError(
"'approx' and 'full' are mutually exclusive");
// approx/full forms: per the messages, f32 only and no rounding mode. The
// condition guarding the f32-only diagnostic is not visible here.
3304 if (isApprox || isFull) {
3306 return emitOpError(
"'approx' and 'full' forms are f32-only");
3307 if (rndMode != NVVM::FPRoundingMode::NONE)
3309 "'approx' and 'full' forms do not accept a rounding mode");
// Rounded divide: a rounding mode is mandatory.
3314 if (rndMode == NVVM::FPRoundingMode::NONE)
3315 return emitOpError(
"rounding mode cannot be None for the rounded divide");
// The condition guarding this diagnostic is not visible here.
3317 return emitOpError(
"FTZ is not supported for f64");
3328 unsigned sizeInBits,
3330 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3332 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3333 if (mask != 0xffffffffu)
3334 field = builder.CreateAnd(field, builder.getInt32(mask));
3336 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3337 field = builder.CreateShl(field, start);
3339 return builder.CreateOr(
result, field);
/// Builds the shared-memory descriptor value for the op as LLVM IR and maps
/// it to the op's result.
/// NOTE(review): the callee of the partially visible calls and some
/// parameters are not visible in this view.
3342void Tcgen05MmaSmemDescOp::createSmemDescriptor(
Operation &op,
3344 llvm::IRBuilderBase &builder) {
3345 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
// The descriptor starts out as a 64-bit zero constant.
3346 llvm::Value *smemDesc = builder.getInt64(0);
// Each call below passes an operand value plus two integers; presumably these
// are the field width in bits and its start bit -- TODO confirm against the
// helper's signature.
3351 builder, smemDesc, mt.
lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
3353 builder, smemDesc, mt.
lookupValue(thisOp.getStrideDimOffset()), 14, 32);
3359 builder, smemDesc, mt.
lookupValue(thisOp.getLeadingDimMode()), 1, 52);
// Record the final descriptor as the translation of the op's result.
3363 mt.
mapValue(thisOp.getRes()) = smemDesc;
/// Returns the PTX text for mbarrier.init, using the `.shared` variant when
/// `isShared` is true.
/// NOTE(review): the definition of `isShared` is not visible in this view.
3370std::string NVVM::MBarrierInitOp::getPtx() {
3372 return isShared ? std::string(
"mbarrier.init.shared.b64 [%0], %1;")
3373 : std::string(
"mbarrier.init.b64 [%0], %1;");
/// Returns the PTX text for mbarrier.arrive.expect_tx; two variants are
/// visible, with and without the `.shared` qualifier.
/// NOTE(review): the condition selecting between them is not visible in this
/// view.
3376std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3379 ? std::string(
"mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3380 : std::string(
"mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
/// Returns a multi-line PTX snippet built with llvm::formatv around
/// mbarrier.try_wait.parity; the visible lines branch to DONE when predicate
/// P1 is set and otherwise back to LAB_WAIT.
/// NOTE(review): the definition of `isShared`, some string pieces and the
/// formatv arguments are not visible in this view.
3383std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
// Qualifier substituted for `{0}` in the instruction text (presumably; the
// formatv argument list is not visible here).
3385 llvm::StringRef space = isShared ?
".shared" :
"";
3387 return llvm::formatv(
"{\n\t"
3388 ".reg .pred P1; \n\t"
3390 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
3391 "@P1 bra.uni DONE; \n\t"
3392 "bra.uni LAB_WAIT; \n\t"
3409 LLVM::FNegOp::create(rewriter, loc, op.getRhs().getType(), op.getRhs());
3412 op.getRnd(), op.getSat(), op.getFtz());
3431 return aligned ? llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count
3432 : llvm::Intrinsic::nvvm_barrier_cta_sync_count;
3434 return aligned ? llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all
3435 : llvm::Intrinsic::nvvm_barrier_cta_sync_all;
3440static llvm::Intrinsic::ID
3443 case NVVM::BarrierReduction::AND:
3444 return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all
3445 : llvm::Intrinsic::nvvm_barrier_cta_red_and_all;
3446 case NVVM::BarrierReduction::OR:
3447 return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all
3448 : llvm::Intrinsic::nvvm_barrier_cta_red_or_all;
3449 case NVVM::BarrierReduction::POPC:
3450 return aligned ? llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all
3451 : llvm::Intrinsic::nvvm_barrier_cta_red_popc_all;
3453 llvm_unreachable(
"unknown BarrierReduction kind");
3458 auto thisOp = cast<NVVM::BarrierOp>(op);
3459 llvm::Value *barrierId = thisOp.getBarrierId()
3461 : builder.getInt32(0);
3462 bool hasCount =
static_cast<bool>(thisOp.getNumberOfThreads());
3463 llvm::Intrinsic::ID
id =
3467 args.push_back(mt.
lookupValue(thisOp.getNumberOfThreads()));
3468 return {id, std::move(args)};
3473 auto thisOp = cast<NVVM::BarrierArriveOp>(op);
3474 llvm::Value *barrierId = thisOp.getBarrierId()
3476 : builder.getInt32(0);
3477 llvm::Value *numThreads = mt.
lookupValue(thisOp.getNumberOfThreads());
3478 llvm::Intrinsic::ID
id =
3480 ? llvm::Intrinsic::nvvm_barrier_cta_arrive_aligned_count
3481 : llvm::Intrinsic::nvvm_barrier_cta_arrive_count;
3482 return {id, {barrierId, numThreads}};
3487 auto thisOp = cast<NVVM::BarrierReductionOp>(op);
3489 thisOp.getAligned(), thisOp.getReductionOp());
3490 llvm::Value *barrierId = thisOp.getBarrierId()
3492 : builder.getInt32(0);
3495 builder.CreateICmpNE(mt.
lookupValue(thisOp.getReductionPredicate()),
3496 builder.getInt32(0))};
3497 return {id, std::move(args)};
3502 llvm::IRBuilderBase &builder) {
3503 auto thisOp = cast<NVVM::CosOp>(op);
3504 llvm::Intrinsic::ID
id = thisOp.getFtz()
3505 ? llvm::Intrinsic::nvvm_cos_approx_ftz_f
3506 : llvm::Intrinsic::nvvm_cos_approx_f;
3512 llvm::IRBuilderBase &builder) {
3513 auto thisOp = cast<NVVM::SinOp>(op);
3514 llvm::Intrinsic::ID
id = thisOp.getFtz()
3515 ? llvm::Intrinsic::nvvm_sin_approx_ftz_f
3516 : llvm::Intrinsic::nvvm_sin_approx_f;
3522 llvm::IRBuilderBase &builder) {
3523 auto thisOp = cast<NVVM::Log2Op>(op);
3524 llvm::Intrinsic::ID
id = thisOp.getFtz()
3525 ? llvm::Intrinsic::nvvm_lg2_approx_ftz_f
3526 : llvm::Intrinsic::nvvm_lg2_approx_f;
3532 llvm::IRBuilderBase &builder) {
3533 auto thisOp = cast<NVVM::Ex2Op>(op);
3534 llvm::Intrinsic::ID
id = thisOp.getFtz()
3535 ? llvm::Intrinsic::nvvm_ex2_approx_ftz
3536 : llvm::Intrinsic::nvvm_ex2_approx;
3542 llvm::IRBuilderBase &builder) {
3543 auto thisOp = cast<NVVM::RsqrtOp>(op);
3544 Type t = thisOp.getRes().getType();
3545 bool isFtz = thisOp.getFtz();
3547 llvm::Intrinsic::ID
id = [&] {
3549 return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_f
3550 : llvm::Intrinsic::nvvm_rsqrt_approx_f;
3553 return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_d
3554 : llvm::Intrinsic::nvvm_rsqrt_approx_d;
3562 llvm::IRBuilderBase &builder) {
3563 auto thisOp = cast<NVVM::SqrtOp>(op);
3564 Type t = thisOp.getRes().getType();
3565 NVVM::FPRoundingMode rndMode = thisOp.getRnd();
3566 bool isFtz = thisOp.getFtz();
3570 unsigned rndIndex =
static_cast<unsigned>(rndMode) - 1;
3572 static constexpr llvm::Intrinsic::ID f32IDs[] = {
3573 llvm::Intrinsic::nvvm_sqrt_rn_f,
3574 llvm::Intrinsic::nvvm_sqrt_rm_f,
3575 llvm::Intrinsic::nvvm_sqrt_rp_f,
3576 llvm::Intrinsic::nvvm_sqrt_rz_f,
3578 static constexpr llvm::Intrinsic::ID f32FTZIDs[] = {
3579 llvm::Intrinsic::nvvm_sqrt_rn_ftz_f,
3580 llvm::Intrinsic::nvvm_sqrt_rm_ftz_f,
3581 llvm::Intrinsic::nvvm_sqrt_rp_ftz_f,
3582 llvm::Intrinsic::nvvm_sqrt_rz_ftz_f,
3584 static constexpr llvm::Intrinsic::ID f64IDs[] = {
3585 llvm::Intrinsic::nvvm_sqrt_rn_d,
3586 llvm::Intrinsic::nvvm_sqrt_rm_d,
3587 llvm::Intrinsic::nvvm_sqrt_rp_d,
3588 llvm::Intrinsic::nvvm_sqrt_rz_d,
3591 llvm::Intrinsic::ID
id =
3592 t.
isF32() ? (isFtz ? f32FTZIDs[rndIndex] : f32IDs[rndIndex])
3600 llvm::IRBuilderBase &builder) {
3601 auto thisOp = cast<NVVM::SqrtApproxOp>(op);
3602 llvm::Intrinsic::ID
id = thisOp.getFtz()
3603 ? llvm::Intrinsic::nvvm_sqrt_approx_ftz_f
3604 : llvm::Intrinsic::nvvm_sqrt_approx_f;
3610 llvm::IRBuilderBase &builder) {
3611 auto thisOp = cast<NVVM::DivFOp>(op);
3612 bool isFtz = thisOp.getFtz();
3614 llvm::Intrinsic::ID id;
3616 if (thisOp.getApprox()) {
3617 id = isFtz ? llvm::Intrinsic::nvvm_div_approx_ftz_f
3618 : llvm::Intrinsic::nvvm_div_approx_f;
3619 }
else if (thisOp.getFull()) {
3622 id = isFtz ? llvm::Intrinsic::nvvm_div_full_ftz
3623 : llvm::Intrinsic::nvvm_div_full;
3626 unsigned rndIndex =
static_cast<unsigned>(thisOp.getRnd()) - 1;
3628 static constexpr llvm::Intrinsic::ID f32IDs[] = {
3629 llvm::Intrinsic::nvvm_div_rn_f,
3630 llvm::Intrinsic::nvvm_div_rm_f,
3631 llvm::Intrinsic::nvvm_div_rp_f,
3632 llvm::Intrinsic::nvvm_div_rz_f,
3634 static constexpr llvm::Intrinsic::ID f32FTZIDs[] = {
3635 llvm::Intrinsic::nvvm_div_rn_ftz_f,
3636 llvm::Intrinsic::nvvm_div_rm_ftz_f,
3637 llvm::Intrinsic::nvvm_div_rp_ftz_f,
3638 llvm::Intrinsic::nvvm_div_rz_ftz_f,
3640 static constexpr llvm::Intrinsic::ID f64IDs[] = {
3641 llvm::Intrinsic::nvvm_div_rn_d,
3642 llvm::Intrinsic::nvvm_div_rm_d,
3643 llvm::Intrinsic::nvvm_div_rp_d,
3644 llvm::Intrinsic::nvvm_div_rz_d,
3646 Type t = thisOp.getRes().getType();
3647 id = t.
isF32() ? (isFtz ? f32FTZIDs[rndIndex] : f32IDs[rndIndex])
3657 llvm::IRBuilderBase &builder) {
3658 auto thisOp = cast<NVVM::PMEventOp>(op);
3662 llvm::Value *maskVal;
3663 if (
auto eventAttr = thisOp.getEventIdAttr()) {
3664 uint16_t mask =
static_cast<uint16_t
>(1u << eventAttr.getInt());
3665 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3668 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3671 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3676 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3678 llvm::Intrinsic::ID
id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3679 : llvm::Intrinsic::nvvm_mbarrier_init;
3684 args.push_back(mt.
lookupValue(thisOp.getCount()));
3686 return {id, std::move(args)};
3691 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3693 llvm::Intrinsic::ID
id = isShared
3694 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3695 : llvm::Intrinsic::nvvm_mbarrier_inval;
3702 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3705 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3708 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3710 static constexpr llvm::Intrinsic::ID IDs[] = {
3711 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3712 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3713 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3714 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3719 args.push_back(mt.
lookupValue(thisOp.getTxcount()));
3721 return {IDs[
index], std::move(args)};
3726 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3729 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3732 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3734 static constexpr llvm::Intrinsic::ID IDs[] = {
3735 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3736 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3737 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3738 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3743 args.push_back(mt.
lookupValue(thisOp.getTxcount()));
3745 return {IDs[
index], std::move(args)};
3750 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
3753 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3756 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3758 static constexpr llvm::Intrinsic::ID IDs[] = {
3759 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
3760 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
3761 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
3762 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
3763 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3764 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
3765 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
3766 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
3768 nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
3769 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3773 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3780 bool hasCount =
static_cast<bool>(thisOp.getCount());
3782 (
id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
3783 return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};
3787 llvm::Value *count =
3789 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3790 return {id, {mbar, count}};
3795 auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
3798 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3801 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3803 static constexpr llvm::Intrinsic::ID IDs[] = {
3804 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
3805 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
3806 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
3807 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
3808 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3809 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
3811 nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
3813 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
3815 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
3816 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3820 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3826 bool hasCount =
static_cast<bool>(thisOp.getCount());
3827 llvm::Value *count =
3829 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3831 return {id, {mbar, count}};
/// Collects the inline-asm operands for this op.
/// NOTE(review): the parameter list and the loop body are not visible in this
/// view; the only visible statement iterates over all operands of the op.
3834bool MBarrierArriveExpectTxOp::getAsmValues(
3841 for (
auto val : getOperands())
// Visible portion of the LLVM-intrinsic selection for
// NVVM::MBarrierArriveExpectTxOp (signature not shown). It derives an
// intrinsic ID from the op's scope, address space and `relaxed` flag and
// returns it together with the arguments {mbar, txcount}.
3849 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3852 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
// Two-bit table index: bit 1 is set for cluster scope, bit 0 for cluster
// space. NOTE(review): isClusterSpace is computed on a line not shown here.
3855 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
// Entries are ordered to match `index`: scope_cta/space_cta,
// scope_cta/space_cluster, scope_cluster/space_cta,
// scope_cluster/space_cluster.
3858 static constexpr llvm::Intrinsic::ID IDs[] = {
3859 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3860 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3861 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3862 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
// Same ordering as IDs, holding the `relaxed` intrinsic variants.
3863 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3864 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3865 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3866 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3867 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
// The op's `relaxed` flag decides which of the two tables is consulted.
3869 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
// Translated LLVM values for the `txcount` and `addr` operands.
3872 llvm::Value *txcount = mt.
lookupValue(thisOp.getTxcount());
3873 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
// The barrier address is passed first, followed by the transaction count.
3878 return {id, {mbar, txcount}};
// Visible portion of the LLVM-intrinsic selection for
// NVVM::MBarrierArriveDropExpectTxOp (signature not shown). It derives an
// intrinsic ID from the op's scope, address space and `relaxed` flag and
// returns it together with the arguments {mbar, txcount}.
3883 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3886 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
// Two-bit table index: bit 1 is set for cluster scope, bit 0 for cluster
// space. NOTE(review): isClusterSpace is computed on a line not shown here.
3889 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
// Entries are ordered to match `index`: scope_cta/space_cta,
// scope_cta/space_cluster, scope_cluster/space_cta,
// scope_cluster/space_cluster.
3892 static constexpr llvm::Intrinsic::ID IDs[] = {
3893 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3894 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3895 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3896 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
// Same ordering as IDs, holding the `relaxed` intrinsic variants.
3897 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3898 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3899 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3900 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3901 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
// The op's `relaxed` flag decides which of the two tables is consulted.
3903 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
// Translated LLVM values for the `txcount` and `addr` operands.
3906 llvm::Value *txcount = mt.
lookupValue(thisOp.getTxcount());
3907 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
// The barrier address is passed first, followed by the transaction count.
3912 return {id, {mbar, txcount}};
// Visible portion of the LLVM-intrinsic selection for
// NVVM::MBarrierArriveNocompleteOp (signature not shown).
3917 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
// Picks the `_shared` intrinsic when isShared is true, the generic one
// otherwise. NOTE(review): isShared is computed on a line not shown here.
3919 llvm::Intrinsic::ID
id =
3920 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3921 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
// Appends the translated `count` operand; `args` is declared and possibly
// pre-filled on lines outside this excerpt.
3925 args.push_back(mt.
lookupValue(thisOp.getCount()));
3927 return {id, std::move(args)};
// Visible portion of the LLVM-intrinsic selection for
// NVVM::MBarrierArriveDropNocompleteOp (signature not shown).
3932 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
// Picks the `_shared` intrinsic when isShared is true, the generic one
// otherwise. NOTE(review): isShared is computed on a line not shown here.
3934 llvm::Intrinsic::ID
id =
3935 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3936 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
// Appends the translated `count` operand; `args` is declared and possibly
// pre-filled on lines outside this excerpt.
3940 args.push_back(mt.
lookupValue(thisOp.getCount()));
3942 return {id, std::move(args)};
// Visible portion of the LLVM-intrinsic selection for
// NVVM::MBarrierTestWaitOp (signature not shown). Returns the chosen
// intrinsic ID together with the arguments {mbar, input}.
3947 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
// A 32-bit integer `stateOrPhase` operand selects the `parity` intrinsics.
3948 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3949 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
// Two-bit table index: bit 1 is set for cluster scope, bit 0 for parity.
3952 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
// Entries are ordered to match `index`: cta, cta+parity, cluster,
// cluster+parity. All listed intrinsics use the space_cta suffix.
3955 static constexpr llvm::Intrinsic::ID IDs[] = {
3956 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3957 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3958 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3959 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
// Same ordering as IDs, holding the `relaxed` intrinsic variants.
3960 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3961 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3962 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3963 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3964 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
// The op's `relaxed` flag decides which of the two tables is consulted.
3966 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
// Translated LLVM values for the `addr` and `stateOrPhase` operands.
3969 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3970 llvm::Value *input = mt.
lookupValue(thisOp.getStateOrPhase());
3975 return {id, {mbar, input}};
// Visible portion of the LLVM-intrinsic selection for
// NVVM::MBarrierTryWaitOp (signature not shown). The intrinsic depends on
// parity vs. state input, scope, presence of a `ticks` operand, and the
// `relaxed` flag.
3980 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
// A 32-bit integer `stateOrPhase` operand selects the `parity` intrinsics.
3981 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3982 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3983 bool hasTicks =
static_cast<bool>(thisOp.getTicks());
// Three-bit table index: bit 2 = ticks present (the `_tl` intrinsics),
// bit 1 = cluster scope, bit 0 = parity.
3987 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3988 (isPhaseParity ? 1 : 0);
// Entries 0-3 are the variants without ticks, entries 4-7 the `_tl`
// variants; within each half the order is cta, cta+parity, cluster,
// cluster+parity.
3991 static constexpr llvm::Intrinsic::ID IDs[] = {
3992 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3993 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3994 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3995 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3996 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3997 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3998 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3999 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
// Same ordering as IDs, holding the `relaxed` intrinsic variants.
4000 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
4001 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
4002 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
4003 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
4004 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
4005 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
4006 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
4007 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
4008 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
// The op's `relaxed` flag decides which of the two tables is consulted.
4010 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
4013 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
// Argument order: barrier address, state-or-phase value, then ticks.
// NOTE(review): the ticks push is presumably guarded by hasTicks on an
// elided line -- confirm.
4020 args.push_back(mbar);
4021 args.push_back(mt.
lookupValue(thisOp.getStateOrPhase()));
4023 args.push_back(mt.
lookupValue(thisOp.getTicks()));
4025 return {id, std::move(args)};
// Visible portion of the LLVM-intrinsic selection for
// NVVM::CpAsyncMBarrierArriveOp (signature, argument setup and return are
// outside this excerpt).
4030 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
4033 llvm::Intrinsic::ID id;
// With `noinc` set the `_noinc` intrinsics are used; in both cases isShared
// selects the `_shared` variant. NOTE(review): isShared is computed on a line
// not shown here, and the second assignment is presumably the else-branch.
4034 if (thisOp.getNoinc()) {
4035 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
4036 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
4038 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
4039 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
4047 llvm::IRBuilderBase &builder) {
// Visible portion of the lowering for NVVM::MovMatrixOp: the returned
// intrinsic is always nvvm_movmatrix_sync_aligned_m8n8_trans_b16. The
// argument list that completes the return statement is not shown here.
4048 auto thisOp = cast<NVVM::MovMatrixOp>(op);
4049 return {llvm::Intrinsic::nvvm_movmatrix_sync_aligned_m8n8_trans_b16,
// Token-pastes `mod`, `size` and `suffix` into the enumerator name
// llvm::Intrinsic::nvvm_cp_async_<mod>_shared_global_<size><suffix>.
4053#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
4054 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
// Expands to a conditional expression: the `_s`-suffixed intrinsic when
// `has_cpsize` is true, the unsuffixed one otherwise. The expansion is not
// parenthesized, so callers should wrap it when embedding it in a larger
// expression.
4056#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
4057 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
// Visible portion of the LLVM-intrinsic selection for NVVM::CpAsyncOp
// (signature, most switch cases and the return are outside this excerpt).
4062 llvm::Intrinsic::ID id;
4064 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
// True when the optional `cpSize` operand is present.
4065 bool hasCpSize =
static_cast<bool>(cpAsyncOp.getCpSize());
// Dispatches on the copy size; only fragments of the cases are visible.
4066 switch (cpAsyncOp.getSize()) {
// In this case the cache modifier (CG or not) takes part in choosing the ID;
// the two arms of the conditional are on elided lines.
4074 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
// Any size not handled by a case is treated as a programming error.
4079 llvm_unreachable(
"Invalid copy size in CpAsyncOp.");
// Argument order: destination, source, then cpSize.
// NOTE(review): the cpSize push is presumably guarded by hasCpSize on an
// elided line -- confirm.
4083 args.push_back(mt.
lookupValue(cpAsyncOp.getDst()));
4084 args.push_back(mt.
lookupValue(cpAsyncOp.getSrc()));
4086 args.push_back(mt.
lookupValue(cpAsyncOp.getCpSize()));
// Visible portion of the lowering for NVVM::CpAsyncBulkPrefetchOp
// (signature not shown). The intrinsic is always
// nvvm_cp_async_bulk_prefetch_L2.
4093 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
4095 llvm::Intrinsic::ID
id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
4098 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
// NOTE(review): `cacheHint` is fetched on a line not shown here.
4102 const bool hasCacheHint =
static_cast<bool>(cacheHint);
// A zero i64 stands in for the cache-hint argument when none is supplied;
// the trailing i1 argument records whether the hint is real.
4103 llvm::Value *i64Unused =
4104 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4105 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4106 args.push_back(builder.getInt1(hasCacheHint));
4108 return {id, std::move(args)};
// Visible portion of the lowering for
// NVVM::CpAsyncBulkGlobalToSharedClusterOp (signature not shown).
4113 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
// Destination and source memory operands; further arguments may be pushed on
// elided lines in between.
4117 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
4119 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4123 mlir::Value multicastMask = thisOp.getMulticastMask();
4124 const bool hasMulticastMask =
static_cast<bool>(multicastMask);
// Zero i16 placeholder prepared for the case where no multicast mask is
// given (the else-arm of the conditional below is on an elided line).
4127 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
4128 args.push_back(hasMulticastMask ? mt.
lookupValue(multicastMask)
// NOTE(review): `cacheHint` is fetched on a line not shown here.
4134 const bool hasCacheHint =
static_cast<bool>(cacheHint);
// Zero i64 placeholder used when no cache hint is given.
4135 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
4136 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
// i1 flags recording whether the mask and the hint are real values.
4140 args.push_back(builder.getInt1(hasMulticastMask));
4141 args.push_back(builder.getInt1(hasCacheHint));
// Chooses between the shared-CTA and shared-cluster destination intrinsics;
// the condition of this conditional is on an elided line.
4143 llvm::Intrinsic::ID
id =
4145 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
4146 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
4148 return {id, std::move(args)};
// Visible portion of the lowering for NVVM::CpAsyncBulkSharedCTAToGlobalOp
// (signature not shown).
4153 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
// Default intrinsic; replaced further down when a byte mask is present.
4155 llvm::Intrinsic::ID
id =
4156 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
4159 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
4160 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
// NOTE(review): `cacheHint` is fetched on a line not shown here.
4164 const bool hasCacheHint =
static_cast<bool>(cacheHint);
// A zero i64 stands in for the cache-hint argument when none is supplied;
// the following i1 argument records whether the hint is real.
4165 llvm::Value *i64Unused =
4166 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4167 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4168 args.push_back(builder.getInt1(hasCacheHint));
// When the optional byte mask operand exists, switch to the `_bytemask`
// intrinsic. Any extra argument pushed for the mask is on an elided line.
4171 if (
mlir::Value byteMask = thisOp.getByteMask()) {
4173 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
4176 return {id, std::move(args)};
// Start of CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues. The
// visible part iterates over every operand of the op; the parameter list,
// loop body and return value are outside this excerpt.
4179bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
4186 for (
auto val : getOperands())
4193CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
// Builds the argument list and picks the g2s tensor-copy intrinsic for this
// op. The intrinsic is looked up by load mode and coordinate count, from a
// separate table when the op is CTA-only. Several lines of the original
// definition are missing from this excerpt.
4195 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
4196 const bool isCTAOnly = thisOp.getIsCTAOnly();
4200 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
4202 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
// NOTE(review): `mcMask` and `cacheHint` are fetched on lines not shown here.
// Zero constants act as placeholders when the optional operand is absent.
4212 const bool hasMC =
static_cast<bool>(mcMask);
4213 llvm::Value *i16Zero =
4214 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.
getLLVMContext()), 0);
4218 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4219 llvm::Value *i64Zero =
4220 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
// Maps the optional `group` attribute to an integer: enum value + 1 when
// present, 0 when absent. Presumably this initializes `val`, which is turned
// into an i32 constant just below -- the declaring lines are elided.
4226 thisOp.getGroup() ? (
static_cast<int32_t
>(*thisOp.getGroup()) + 1) : 0;
4228 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.
getLLVMContext()), val);
// Two argument layouts are visible: one with multicast mask, cache hint and
// both i1 flags, and one with only the cache hint and its flag.
// NOTE(review): the branch separating them (likely on isCTAOnly) is elided.
4232 args.push_back(hasMC ? mt.
lookupValue(mcMask) : i16Zero);
4233 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Zero);
4234 args.push_back(builder.getInt1(hasMC));
4235 args.push_back(builder.getInt1(hasCacheHint));
4239 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Zero);
4240 args.push_back(builder.getInt1(hasCacheHint));
// Lookup tables: one row per load mode (numModes rows), one column per
// coordinate count 0..numDims. Unsupported combinations hold notIntrinsic.
4243 constexpr size_t numDims = 5;
4244 constexpr size_t numModes = 5;
4245 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
4246 using TableTy = std::array<rowTy, numModes>;
// Table used when the op is not CTA-only.
4247 static constexpr TableTy IDTable{
4248 {{
notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
4249 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
4250 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
4251 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
4252 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
4254 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
4255 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
4256 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
4258 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
4259 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
4260 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
4262 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
4263 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
4264 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
4266 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
// Table used when the op is CTA-only (`g2s_cta` intrinsics), same layout.
4268 static constexpr TableTy IDTableCTA{
4270 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
4271 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
4272 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
4273 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
4274 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
4276 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
4277 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
4278 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
4280 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
4281 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
4282 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
4284 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
4285 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
4286 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
4288 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
// Compile-time check (its opening line is elided) that both tables have one
// row per TMALoadMode enumerator.
4291 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
4292 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
4293 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
// Row = numeric value of the mode enum, column = number of coordinates.
4294 size_t mode =
static_cast<size_t>(thisOp.getMode());
4295 size_t dim = thisOp.getCoordinates().size();
4296 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
4298 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
4300 return {id, std::move(args)};
// Visible portion of the lowering for NVVM::CpAsyncBulkTensorPrefetchOp
// (signature not shown). Arguments are the TMA descriptor, the coordinates,
// the im2col offsets, the cache hint and its presence flag; the intrinsic is
// looked up by load mode and coordinate count.
4305 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
4309 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
// The bodies of these two loops are on elided lines.
4311 for (
auto v : thisOp.getCoordinates())
4313 for (
auto v : thisOp.getIm2colOffsets())
// NOTE(review): `cacheHint` is fetched on a line not shown here.
4317 const bool hasCacheHint =
static_cast<bool>(cacheHint);
// A zero i64 stands in for the cache-hint argument when none is supplied;
// the following i1 argument records whether the hint is real.
4318 llvm::Value *i64Unused =
4319 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4320 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4321 args.push_back(builder.getInt1(hasCacheHint));
// Short alias for "no intrinsic" to keep the table rows readable.
4323 const unsigned NI = llvm::Intrinsic::not_intrinsic;
// One row per load mode, one column per coordinate count 0..5; NI marks
// unsupported combinations.
4324 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4325 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
4326 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
4327 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
4328 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
4329 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
4331 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
4332 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
4333 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
4335 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
4336 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
4337 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
4339 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
4340 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
4341 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
4342 {NI, NI, NI, NI, NI,
4343 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
// Guards against the enum and the table getting out of sync.
4345 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
4346 "TMALoadModes must match number of rows in IDTable");
// Row = numeric value of the mode enum, column = number of coordinates.
4347 size_t mode =
static_cast<size_t>(thisOp.getMode());
4348 size_t dim = thisOp.getCoordinates().size();
4349 llvm::Intrinsic::ID
id = IDTable[mode][dim];
// An NI entry here means the mode/dimension pair is unsupported and is
// treated as a programming error.
4350 if (
id == llvm::Intrinsic::not_intrinsic)
4351 llvm_unreachable(
"Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
4353 return {id, std::move(args)};
4357CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
// Builds the argument list (source memory, TMA descriptor, coordinates,
// cache hint and its presence flag) and looks up the s2g tensor-copy
// intrinsic by store mode and coordinate count. Several lines of the
// original definition are missing from this excerpt.
4359 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
4363 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4364 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
// The loop body is on an elided line.
4366 for (
auto v : thisOp.getCoordinates())
// NOTE(review): `cacheHint` is fetched on a line not shown here.
4370 const bool hasCacheHint =
static_cast<bool>(cacheHint);
// A zero i64 stands in for the cache-hint argument when none is supplied;
// the following i1 argument records whether the hint is real.
4371 llvm::Value *i64Unused =
4372 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4373 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4374 args.push_back(builder.getInt1(hasCacheHint));
// Short alias for "no intrinsic" to keep the table rows readable.
4376 const unsigned NI = llvm::Intrinsic::not_intrinsic;
// One row per store mode (tile, im2col, tile_scatter4), one column per
// coordinate count 0..5; NI marks unsupported combinations.
4377 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4378 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
4379 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
4380 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
4381 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
4382 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
4383 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
4384 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
4385 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
4386 {NI, NI, NI, NI, NI,
4387 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
// Guards against the enum and the table getting out of sync.
4389 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
4390 "TMAStoreModes must match number of rows in IDTable");
// Row = numeric value of the mode enum, column = number of coordinates.
4391 size_t mode =
static_cast<size_t>(thisOp.getMode());
4392 size_t dim = thisOp.getCoordinates().size();
4393 llvm::Intrinsic::ID
id = IDTable[mode][dim];
// An NI entry is reported with the message below; the reporting call itself
// is on an elided line.
4394 if (
id == llvm::Intrinsic::not_intrinsic)
4396 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
4398 return {id, std::move(args)};
// Visible portion of the lowering for NVVM::CpAsyncBulkTensorReduceOp
// (signature not shown). Arguments are the source memory, TMA descriptor,
// coordinates, cache hint and its presence flag; the intrinsic is looked up
// in a three-level table indexed by reduction kind, mode and coordinate
// count.
4403 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
4411 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4412 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
// The loop body is on an elided line.
4414 for (
Value v : thisOp.getCoordinates())
// NOTE(review): `cacheHint` and `ctx` are defined on lines not shown here.
4418 const bool hasCacheHint =
static_cast<bool>(cacheHint);
// A zero i64 stands in for the cache-hint argument when none is supplied;
// the following i1 argument records whether the hint is real.
4419 llvm::Value *i64ZeroValue =
4420 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
4421 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64ZeroValue);
4422 args.push_back(builder.getInt1(hasCacheHint));
// Local alias for the "no intrinsic" sentinel; it shadows the file-level
// constant of the same name.
4424 const llvm::Intrinsic::ID
notIntrinsic = llvm::Intrinsic::not_intrinsic;
// Table shape: numRedKinds reduction kinds x numLayouts layouts x
// (maxDim + 1) coordinate counts.
4426 constexpr unsigned numRedKinds = 8;
4427 constexpr unsigned numLayouts = 2;
4428 constexpr unsigned maxDim = 5;
4429 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
4430 using layoutTable = std::array<row, numLayouts>;
4431 using fullTable = std::array<layoutTable, numRedKinds>;
// Visible groups appear in the order add, min, max, inc, dec, and, or, xor;
// each group lists the tile row followed by the im2col row. The leading
// entries of each row are on elided lines.
4432 static constexpr fullTable IDTable{
4435 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
4436 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
4437 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
4438 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
4439 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
4441 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
4442 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
4443 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
4446 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
4447 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
4448 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
4449 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
4450 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
4452 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
4453 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
4454 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
4457 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
4458 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
4459 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
4460 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
4461 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
4463 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
4464 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
4465 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
4468 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
4469 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
4470 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
4471 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
4472 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
4474 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
4475 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
4476 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
4479 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
4480 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
4481 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
4482 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
4483 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
4485 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
4486 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
4487 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
4490 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
4491 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
4492 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
4493 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
4494 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
4496 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
4497 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
4498 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
4501 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
4502 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
4503 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
4504 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
4505 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
4507 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
4508 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
4509 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
4512 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
4513 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
4514 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
4515 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
4516 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
4518 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
4519 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
4521 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
// Guards against the reduction-kind enum and the table getting out of sync.
4523 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
4524 "TMAReduxKinds must match number of rows in IDTable");
// All three indices are the raw numeric values taken from the op.
4526 size_t redKind =
static_cast<size_t>(thisOp.getRedKind());
4527 size_t mode =
static_cast<size_t>(thisOp.getMode());
4528 size_t dim = thisOp.getCoordinates().size();
// Bounds checks for each table level (active only in assert-enabled builds).
4530 assert(redKind < IDTable.size() &&
4531 "Invalid redKind for CpAsyncBulkTensorReduceOp");
4532 assert(mode < IDTable[redKind].size() &&
4533 "Invalid mode for CpAsyncBulkTensorReduceOp");
4534 assert(dim < IDTable[redKind][mode].size() &&
4535 "Invalid dim for CpAsyncBulkTensorReduceOp");
4537 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
4540 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4542 return {intrinsicID, std::move(args)};
// Expands to a conditional on `hasRelu`, a variable that must be in scope at
// the expansion site: nvvm_f2tf32_<rnd><relu><sf> when true,
// nvvm_f2tf32_<rnd><sf> when false.
4547#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4548 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
4549 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
// Expands to a conditional on `hasSatFinite`, also taken from the expansion
// site: passes the `sf` suffix through when true and an empty suffix when
// false. The expansion is an unparenthesized conditional expression.
4551#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
4552 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4553 : CVT_F2TF32_ID_IMPL(rnd, relu, )
4556ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4557 NVVM::SaturationMode sat,
bool hasRelu) {
// Selects an intrinsic from the rounding mode, saturation mode and relu flag.
// The rounding-mode dispatch itself is on elided lines; presumably it uses
// the GET_CVT_F2TF32_ID macro, which reads hasRelu and hasSatFinite.
4558 using RndMode = NVVM::FPRoundingMode;
4559 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
// A rounding mode that is not handled is treated as a programming error.
4568 llvm_unreachable(
"Invalid RoundingMode for CvtFloatToTF32Op");
4573ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4575 llvm::IRBuilderBase &builder) {
// Picks the e2m1x2 conversion intrinsic, with or without relu, and returns it
// with `args`. The construction of `args` is on elided lines.
4580 bool hasRelu = op.getRelu();
4582 llvm::Intrinsic::ID intId =
4583 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4584 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4586 return {intId, std::move(args)};
// Expands to a conditional expression: nvvm_ff_to_<type>_rn_relu_satfinite
// when `has_relu` is true, nvvm_ff_to_<type>_rn_satfinite otherwise. The
// expansion is not parenthesized.
4589#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
4590 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
4591 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
// Maps the destination element type to a conversion intrinsic. Cases are
// visible for Float6E2M3FNType and Float6E3M2FNType (their bodies are on
// elided lines); the visible fallback reports an invalid conversion via
// llvm_unreachable.
4593llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4596 .Case([&](mlir::Float6E2M3FNType) {
4599 .Case([&](mlir::Float6E3M2FNType) {
4603 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF6x2Op");
4604 return llvm::Intrinsic::not_intrinsic;
4609ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
4611 llvm::IRBuilderBase &builder) {
// Picks the f16x2 -> e2m1x2 intrinsic, with or without relu, when the
// destination type is Float4E2M1FNType; otherwise intId keeps its
// not_intrinsic initial value in the visible code.
// NOTE(review): `dstTy` and `args` are set up on lines not shown here.
4613 bool hasRelu = op.getRelu();
4615 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4617 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4618 intId = hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
4619 : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
4624 return {intId, std::move(args)};
4628ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
4630 llvm::IRBuilderBase &builder) {
// Picks the bf16x2 -> e2m1x2 intrinsic, with or without relu, when the
// destination type is Float4E2M1FNType; otherwise intId keeps its
// not_intrinsic initial value in the visible code.
// NOTE(review): `dstTy` and `args` are set up on lines not shown here.
4632 bool hasRelu = op.getRelu();
4634 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4636 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4637 intId = hasRelu ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
4638 : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
4643 return {intId, std::move(args)};
// Maps the destination element type to an f16x2 conversion intrinsic:
// e2m3x2 for Float6E2M3FNType and e3m2x2 for Float6E3M2FNType, each with a
// relu variant chosen by hasRelu. The visible fallback reports an invalid
// conversion via llvm_unreachable.
// NOTE(review): hasRelu is presumably a parameter on an elided line.
4646llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4649 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4650 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite
4651 : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
4653 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4654 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite
4655 : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
4658 llvm_unreachable(
"Invalid conversion in ConvertF16x2ToF6x2Op");
4659 return llvm::Intrinsic::not_intrinsic;
// Maps the destination element type to a bf16x2 conversion intrinsic:
// e2m3x2 for Float6E2M3FNType and e3m2x2 for Float6E3M2FNType, each listed
// with a relu and a non-relu variant. The condition choosing between them
// is on elided lines. The visible fallback reports an invalid conversion via
// llvm_unreachable.
4663llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4666 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4668 ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite
4669 : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
4671 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4673 ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite
4674 : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
4677 llvm_unreachable(
"Invalid conversion in ConvertBF16x2ToF6x2Op");
4678 return llvm::Intrinsic::not_intrinsic;
// Expands to a conditional expression for the ue8m0x2 conversions:
// nvvm_ff_to_ue8m0x2_<rnd>_satfinite when `has_satf` is true,
// nvvm_ff_to_ue8m0x2_<rnd> otherwise. The expansion is not parenthesized.
4682#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
4683 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
4684 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
// Expands to a conditional expression for the remaining f8x2 conversions:
// nvvm_ff_to_<type>_rn_relu when `has_relu` is true, nvvm_ff_to_<type>_rn
// otherwise. The expansion is not parenthesized.
4686#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
4687 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
4688 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4691ConvertF32x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4692 NVVM::SaturationMode sat,
bool hasRelu) {
// Maps the destination element type, rounding mode, saturation mode and relu
// flag to a conversion intrinsic. The statements that return the IDs are on
// elided lines.
4693 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4694 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4695 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
// Cases are visible for Float8E4M3FNType, Float8E5M2Type and
// Float8E8M0FNUType.
4698 .Case([&](mlir::Float8E4M3FNType) {
4701 .Case([&](mlir::Float8E5M2Type) {
4704 .Case([&](mlir::Float8E8M0FNUType) {
// Within the E8M0FNU case only the RZ and RP rounding modes are handled; the
// first llvm_unreachable below presumably covers any other mode.
4705 if (hasRoundingModeRZ)
4707 else if (hasRoundingModeRP)
4710 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
// Fallback for destination types that match none of the cases.
4713 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
4714 return llvm::Intrinsic::not_intrinsic;
// Expands to a conditional expression: nvvm_f16x2_to_<type>_rn_relu when
// `has_relu` is true, nvvm_f16x2_to_<type>_rn otherwise. The expansion is not
// parenthesized.
4718#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
4719 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
4720 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
// Maps the destination element type to an f16x2 conversion intrinsic. Cases
// are visible for Float8E4M3FNType and Float8E5M2Type (their bodies are on
// elided lines); the visible fallback reports an invalid conversion via
// llvm_unreachable.
4722llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy,
4725 .Case([&](mlir::Float8E4M3FNType) {
4728 .Case([&](mlir::Float8E5M2Type) {
4732 llvm_unreachable(
"Invalid conversion in ConvertF16x2ToF8x2Op");
4733 return llvm::Intrinsic::not_intrinsic;
4738ConvertBF16x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy,
4739 NVVM::FPRoundingMode rnd,
4740 NVVM::SaturationMode sat,
bool hasRelu) {
// Maps the destination element type, rounding mode, saturation mode and relu
// flag to a bf16x2 conversion intrinsic.
4741 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
// ue8m0x2 variants indexed by (satfinite << 1) | roundingIsRP:
// rz, rp, rz_satfinite, rp_satfinite.
4743 static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
4744 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
4745 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
4746 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
4747 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
// E4M3FN and E5M2 destinations: a relu and a non-relu variant are listed;
// the condition choosing between them is on elided lines.
4751 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4753 ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
4754 : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
4756 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4758 ? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
4759 : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
// E8M0FNU destination: any rounding mode other than RP maps to the rz
// entries, because only the RP comparison feeds the index.
4761 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
4762 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4763 unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
4764 return ue8m0x2IDs[
index];
// Fallback for destination types that match none of the cases.
4767 llvm_unreachable(
"Invalid conversion in ConvertBF16x2ToF8x2Op");
4768 return llvm::Intrinsic::not_intrinsic;
// Lowering for NVVM::ConvertF8x2ToF16x2Op: picks the e4m3x2/e5m2x2 ->
// f16x2 rn intrinsic (relu variant when the op's relu flag is set), and
// passes the source operand bitcast to i16 as the only argument.
4774 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4776 bool hasRelu = curOp.getRelu();
4778 llvm::Intrinsic::ID intId =
4780 .Case([&](Float8E4M3FNType type) {
4781 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4782 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4784 .Case([&](Float8E5M2Type type) {
4785 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4786 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4789 llvm_unreachable(
"Invalid type for ConvertF8x2ToF16x2Op");
4790 return llvm::Intrinsic::not_intrinsic;
// The translated source value is reinterpreted as a single i16.
4793 llvm::Value *packedI16 =
4794 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4795 llvm::Type::getInt16Ty(builder.getContext()));
4797 return {intId, {packedI16}};
// Lowering for NVVM::ConvertF8x2ToBF16x2Op: selects the intrinsic from the
// source element type and the satfinite/relu flags, then builds the argument
// list starting with the source operand bitcast to i16.
4802 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4803 bool hasScale =
static_cast<bool>(curOp.getScaleFactor());
4804 bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
4805 bool hasRelu = curOp.getRelu();
// Both tables are indexed by (hasSatfinite << 1) | hasRelu.
4807 static constexpr llvm::Intrinsic::ID E4M3Ids[] = {
4808 llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_scale_n2_ue8m0,
4809 llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4810 llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4811 llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4814 static constexpr llvm::Intrinsic::ID E5M2Ids[] = {
4815 llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_scale_n2_ue8m0,
4816 llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4817 llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4818 llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4821 llvm::Intrinsic::ID intId =
// The E8M0 source maps to one fixed intrinsic regardless of the flags.
4823 .Case([&](Float8E8M0FNUType type) {
4824 return llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4826 .Case([&](Float8E4M3FNType type) {
4827 return E4M3Ids[hasSatfinite << 1 | hasRelu];
4829 .Case([&](Float8E5M2Type type) {
4830 return E5M2Ids[hasSatfinite << 1 | hasRelu];
4833 llvm_unreachable(
"Invalid type for ConvertF8x2ToBF16x2Op");
4834 return llvm::Intrinsic::not_intrinsic;
4836 llvm::Value *packedI16 =
4837 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4838 llvm::Type::getInt16Ty(builder.getContext()));
4841 args.push_back(packedI16);
// Non-E8M0 sources take an extra operand whose fallback is the i16 constant
// 0x7f7f (presumably used when no scale factor is given; confirm).
4842 if (!isa<Float8E8M0FNUType>(curOp.getSrcType()))
4845 : builder.getInt16(0x7f7f));
4848 return {intId, std::move(args)};
// Lowering for NVVM::ConvertF6x2ToF16x2Op: picks the e2m3x2/e3m2x2 ->
// f16x2 rn intrinsic (relu variant when the op's relu flag is set), and
// passes the source operand bitcast to i16 as the only argument.
4853 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4855 bool hasRelu = curOp.getRelu();
4857 llvm::Intrinsic::ID intId =
4859 .Case([&](Float6E2M3FNType type) {
4860 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4861 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4863 .Case([&](Float6E3M2FNType type) {
4864 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4865 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4868 llvm_unreachable(
"Invalid type for ConvertF6x2ToF16x2Op");
4869 return llvm::Intrinsic::not_intrinsic;
// The translated source value is reinterpreted as a single i16.
4872 llvm::Value *packedI16 =
4873 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4874 llvm::Type::getInt16Ty(builder.getContext()));
4876 return {intId, {packedI16}};
// Lowering for NVVM::ConvertF6x2ToBF16x2Op: selects the intrinsic from the
// source element type (Float6E2M3FNType or Float6E3M2FNType) and the
// satfinite/relu flags; the first argument is the source bitcast to i16.
4881 auto curOp = cast<NVVM::ConvertF6x2ToBF16x2Op>(op);
4882 bool hasScale =
static_cast<bool>(curOp.getScaleFactor());
4883 bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
4884 bool hasRelu = curOp.getRelu();
// Both tables are indexed by `idx` = (hasSatfinite << 1) | hasRelu.
4886 static constexpr llvm::Intrinsic::ID E2M3Ids[] = {
4887 llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_scale_n2_ue8m0,
4888 llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4889 llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4890 llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4893 static constexpr llvm::Intrinsic::ID E3M2Ids[] = {
4894 llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_scale_n2_ue8m0,
4895 llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4896 llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4897 llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4900 unsigned idx = (hasSatfinite << 1) | hasRelu;
4901 llvm::Intrinsic::ID intId =
4903 .Case([&](Float6E2M3FNType type) {
return E2M3Ids[idx]; })
4904 .Case([&](Float6E3M2FNType type) {
return E3M2Ids[idx]; })
4906 llvm_unreachable(
"Invalid type for ConvertF6x2ToBF16x2Op");
4907 return llvm::Intrinsic::not_intrinsic;
4910 llvm::Value *packedI16 =
4911 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4912 llvm::Type::getInt16Ty(builder.getContext()));
4915 args.push_back(packedI16);
4922 return {intId, std::move(args)};
// Lowering for NVVM::ConvertF4x2ToF16x2Op: picks the e2m1x2 -> f16x2 rn
// intrinsic (relu variant when the op's relu flag is set) and passes the
// source operand zero-extended to i16 as the only argument.
4927 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4929 bool hasRelu = curOp.getRelu();
4931 llvm::Intrinsic::ID intId =
4933 .Case([&](Float4E2M1FNType type) {
4934 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4935 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4938 llvm_unreachable(
"Invalid type for ConvertF4x2ToF16x2Op");
4939 return llvm::Intrinsic::not_intrinsic;
// Unlike the F6/F8 lowerings, which bitcast, the source is zero-extended.
4942 llvm::Value *extendedI16 =
4943 builder.CreateZExt(mt.
lookupValue(curOp.getSrc()),
4944 llvm::Type::getInt16Ty(builder.getContext()));
4946 return {intId, {extendedI16}};
// Lowering for NVVM::ConvertF4x2ToBF16x2Op: selects the e2m1x2 -> bf16x2
// intrinsic from the satfinite/relu flags; the first argument is the source
// operand zero-extended to i16.
4951 auto curOp = cast<NVVM::ConvertF4x2ToBF16x2Op>(op);
4952 bool hasScale =
static_cast<bool>(curOp.getScaleFactor());
4953 bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
4954 bool hasRelu = curOp.getRelu();
// Indexed by `idx` = (hasSatfinite << 1) | hasRelu.
4956 static constexpr llvm::Intrinsic::ID E2M1Ids[] = {
4957 llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_scale_n2_ue8m0,
4958 llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4959 llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4960 llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4963 unsigned idx = (hasSatfinite << 1) | hasRelu;
4964 llvm::Intrinsic::ID intId =
4966 .Case([&](Float4E2M1FNType type) {
return E2M1Ids[idx]; })
4968 llvm_unreachable(
"Invalid type for ConvertF4x2ToBF16x2Op");
4969 return llvm::Intrinsic::not_intrinsic;
4972 llvm::Value *extendedI16 =
4973 builder.CreateZExt(mt.
lookupValue(curOp.getSrc()),
4974 llvm::Type::getInt16Ty(builder.getContext()));
4977 args.push_back(extendedI16);
4984 return {intId, std::move(args)};
// Lowering for NVVM::ConvertF32x2ToS2F6x2Op: chooses between the relu and
// non-relu nvvm_ff_to_s2f6x2 rn/satfinite/scale intrinsics. The scale operand
// is the op's scale factor when present, otherwise the i16 constant 0x7f7f.
4989 auto thisOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);
4990 bool hasRelu = thisOp.getRelu();
4991 bool hasScale =
static_cast<bool>(thisOp.getScaleFactor());
4993 llvm::Intrinsic::ID
id =
4995 ? llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
4996 : llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
5002 args.push_back(hasScale ? mt.
lookupValue(thisOp.getScaleFactor())
5003 : builder.getInt16(0x7f7f));
5004 return {id, std::move(args)};
// Lowering for NVVM::ConvertBF16x2ToS2F6x2Op: chooses between the relu and
// non-relu nvvm_bf16x2_to_s2f6x2 rn/satfinite/scale intrinsics. The scale
// operand is the op's scale factor when present, otherwise the i16 constant
// 0x7f7f.
5009 auto thisOp = cast<NVVM::ConvertBF16x2ToS2F6x2Op>(op);
5010 bool hasRelu = thisOp.getRelu();
5011 bool hasScale =
static_cast<bool>(thisOp.getScaleFactor());
5013 llvm::Intrinsic::ID
id =
5016 nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
5017 : llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
5022 args.push_back(hasScale ? mt.
lookupValue(thisOp.getScaleFactor())
5023 : builder.getInt16(0x7f7f));
5024 return {id, std::move(args)};
// Lowering for NVVM::ConvertS2F6x2ToBF16x2Op: selects the intrinsic from the
// satfinite/relu flags and passes two arguments: the source bitcast to i16,
// then the scale factor (or the i16 constant 0x7f7f when none is given).
5029 auto thisOp = cast<NVVM::ConvertS2F6x2ToBF16x2Op>(op);
5030 bool hasRelu = thisOp.getRelu();
5031 bool hasScale =
static_cast<bool>(thisOp.getScaleFactor());
5032 bool hasSat = thisOp.getSat() == NVVM::SaturationMode::SATFINITE;
// Indexed by `idx` = (hasSat << 1) | hasRelu.
5034 static constexpr llvm::Intrinsic::ID ids[] = {
5035 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0,
5036 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
5037 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
5038 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
5041 unsigned idx = (hasSat << 1) | hasRelu;
5045 llvm::Value *packedI16 =
5046 builder.CreateBitCast(mt.
lookupValue(thisOp.getSrc()),
5047 llvm::Type::getInt16Ty(builder.getContext()));
5048 args.push_back(packedI16);
5049 args.push_back(hasScale ? mt.
lookupValue(thisOp.getScaleFactor())
5050 : builder.getInt16(0x7f7f));
5052 return {ids[idx], std::move(args)};
// Chooses the tcgen05.alloc intrinsic for NVVM::Tcgen05AllocOp. `as` is
// derived from the pointer type of the op's addr operand; when it equals
// NVVMMemorySpace::Shared a *_shared_* intrinsic is chosen. The cg1/cg2
// suffix follows whether the op's group is CTAGroupKind::CTA_2.
5056Tcgen05AllocOp::getIntrinsicIDAndArgs(
Operation &op,
5059 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
5060 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
5062 bool isShared = as == NVVMMemorySpace::Shared;
5063 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
5065 llvm::Intrinsic::ID id;
5067 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
5068 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
5070 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
5071 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
// Chooses the tcgen05.dealloc intrinsic for NVVM::Tcgen05DeallocOp: the cg1
// variant when the op's group is CTAGroupKind::CTA_1, otherwise cg2.
5081llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
5084 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
5085 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
5086 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
5087 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
// Expands to a conditional expression choosing between
// nvvm_tcgen05_commit<mc>_shared_<cg> (when `is_shared` is true) and
// nvvm_tcgen05_commit<mc>_<cg>. The expansion is not parenthesized.
5096#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
5097 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
5098 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
// Wraps TCGEN05_COMMIT_IMPL: passes `_mc` as the multicast infix when
// `has_mc` is true and an empty infix otherwise.
5100#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
5101 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
5102 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
// Builds the intrinsic ID and arguments for NVVM::Tcgen05CommitOp. The
// selection inputs computed here are: whether the addr pointer compares
// equal to NVVMMemorySpace::Shared, whether a multicast mask operand is
// present, and whether the group is CTAGroupKind::CTA_2. The translated
// multicast mask is appended to the arguments.
5105Tcgen05CommitOp::getIntrinsicIDAndArgs(
Operation &op,
5108 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
5109 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
5111 bool isShared = as == NVVMMemorySpace::Shared;
5112 bool hasMulticast =
static_cast<bool>(curOp.getMulticastMask());
5113 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
5115 llvm::Intrinsic::ID
id =
5122 args.push_back(mt.
lookupValue(curOp.getMulticastMask()));
// Token-pastes the shape/multicast infix, source-format infix and cta-group
// suffix into an nvvm_tcgen05_cp* intrinsic name.
5127#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
5128 llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
// Chooses the _cg2 suffix when `is_2cta` is true and _cg1 otherwise. The
// expansion is an unparenthesized conditional expression.
5130#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
5131 is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
5132 : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
// Maps the runtime `src_fmt` value to a source-format infix: B6x16_P32 ->
// _b6x16_p32, B4x16_P64 -> _b4x16_p64, anything else -> no infix.
5134#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
5136 if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
5137 return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
5138 if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
5139 return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
5140 return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
5144ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
5146 llvm::IRBuilderBase &builder) {
// Selects the nvvm_ff2f16x2_* intrinsic for the op's rounding mode (RN, RZ
// or RS) and its relu/satfinite flags. Each per-rounding-mode table is
// indexed by `idx` = (hasSatFinite << 1) | hasRelu.
5147 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
5148 llvm::Intrinsic::nvvm_ff2f16x2_rn,
5149 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
5150 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
5151 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
5153 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
5154 llvm::Intrinsic::nvvm_ff2f16x2_rz,
5155 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
5156 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
5157 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
5159 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
5160 llvm::Intrinsic::nvvm_ff2f16x2_rs,
5161 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
5162 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
5163 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
5166 unsigned hasRelu = op.getRelu() ? 1 : 0;
5167 unsigned hasSatFinite =
5168 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
5171 unsigned idx = (hasSatFinite << 1) | hasRelu;
// The random-bits operand is appended only when the op carries one.
5176 if (op.getRandomBits())
5177 args.push_back(mt.
lookupValue(op.getRandomBits()));
5179 switch (op.getRnd()) {
5180 case FPRoundingMode::RN:
5181 return {rndRNIds[idx], std::move(args)};
5182 case FPRoundingMode::RZ:
5183 return {rndRZIds[idx], std::move(args)};
5184 case FPRoundingMode::RS:
5185 return {rndRSIds[idx], std::move(args)};
5187 llvm_unreachable(
"Invalid rounding mode for ConvertF32x2ToF16x2Op");
// Selects the nvvm_ff2bf16x2_* intrinsic for the op's rounding mode (RN, RZ
// or RS) and its relu/satfinite flags. Each per-rounding-mode table is
// indexed by `idx` = (hasSatFinite << 1) | hasRelu.
5192ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
5194 llvm::IRBuilderBase &builder) {
5195 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
5196 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
5197 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
5198 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
5199 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
5201 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
5202 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
5203 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
5204 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
5205 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
5207 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
5208 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
5209 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
5210 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
5211 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
5214 unsigned hasRelu = op.getRelu() ? 1 : 0;
5215 unsigned hasSatFinite =
5216 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
5219 unsigned idx = (hasSatFinite << 1) | hasRelu;
// The random-bits operand is appended only when the op carries one.
5224 if (op.getRandomBits())
5225 args.push_back(mt.
lookupValue(op.getRandomBits()));
5227 switch (op.getRnd()) {
5228 case FPRoundingMode::RN:
5229 return {rndRNIds[idx], std::move(args)};
5230 case FPRoundingMode::RZ:
5231 return {rndRZIds[idx], std::move(args)};
5232 case FPRoundingMode::RS:
5233 return {rndRSIds[idx], std::move(args)};
5235 llvm_unreachable(
"Invalid rounding mode for ConvertF32x2ToBF16x2Op");
// Returns the f32x4 -> f8x4 rs/satfinite intrinsic for Float8E4M3FNType or
// Float8E5M2Type, using the relu variant when the op's relu flag is set. Any
// other type reaches llvm_unreachable.
5239llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
5241 bool hasRelu = getRelu();
5244 .Case([&](mlir::Float8E4M3FNType) {
5245 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
5246 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
5248 .Case([&](mlir::Float8E5M2Type) {
5249 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
5250 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
5253 llvm_unreachable(
"Invalid F8 type in ConvertF32x4ToF8x4Op");
5254 return llvm::Intrinsic::not_intrinsic;
// Returns the f32x4 -> f6x4 rs/satfinite intrinsic for Float6E2M3FNType or
// Float6E3M2FNType, using the relu variant when the op's relu flag is set.
// Any other type reaches llvm_unreachable.
5258llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
5260 bool hasRelu = getRelu();
5263 .Case([&](mlir::Float6E2M3FNType) {
5264 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
5265 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
5267 .Case([&](mlir::Float6E3M2FNType) {
5268 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
5269 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
5272 llvm_unreachable(
"Invalid F6 type in ConvertF32x4ToF6x4Op");
5273 return llvm::Intrinsic::not_intrinsic;
// Returns the f32x4 -> e2m1x4 rs/satfinite intrinsic for Float4E2M1FNType,
// using the relu variant when the op's relu flag is set. Any other type
// reaches llvm_unreachable.
5277llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
5279 bool hasRelu = getRelu();
5282 .Case([&](mlir::Float4E2M1FNType) {
5283 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
5284 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
5287 llvm_unreachable(
"Invalid F4 type in ConvertF32x4ToF4x4Op");
5288 return llvm::Intrinsic::not_intrinsic;
// Returns the tcgen05.cp intrinsic for NVVM::Tcgen05CpOp by switching on the
// op's shape. The inputs gathered for the selection are the 2-CTA flag, the
// source format and the multicast kind. An unhandled shape reaches
// llvm_unreachable.
5292llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(
Operation &op) {
5293 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
5294 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
5295 auto srcFmt = curOp.getSrcFormat();
5296 auto mc = curOp.getMulticast();
5298 switch (curOp.getShape()) {
5299 case Tcgen05CpShape::SHAPE_128x256b:
5301 case Tcgen05CpShape::SHAPE_128x128b:
5303 case Tcgen05CpShape::SHAPE_4x256b:
5305 case Tcgen05CpShape::SHAPE_32x128b:
// The 64x128b shape additionally depends on whether the multicast kind is
// WARPX2_01_23.
5307 case Tcgen05CpShape::SHAPE_64x128b:
5308 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
5312 llvm_unreachable(
"Invalid shape in tcgen05 cp Op");
// Shape-specific branches for SHAPE_16X128B and SHAPE_16X256B.
// NOTE(review): confirm what each branch yields against the enclosing helper.
5319 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
5321 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
// Verifier for Tcgen05LdOp. Checks performed here:
//  - shape 16x32bx2 with no offset operand is tested for explicitly;
//  - an offset operand with any other shape is an error;
//  - the result length (element count when the result is a vector) is
//    validated against the shape.
// Failures are recorded by assigning to `result`.
5326LogicalResult Tcgen05LdOp::verify() {
5328 if (
getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
5331 if (
getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
5332 result =
emitError(
"offset argument is only supported for shape 16x32bx2");
5334 auto resTy = getRes().getType();
5335 unsigned resLen = isa<VectorType>(resTy)
5336 ? llvm::cast<VectorType>(resTy).getNumElements()
5339 result =
emitError(llvm::formatv(
"invalid result type length {0} for shape "
5340 "{1} in tcgen05.ld Op",
5341 resLen, stringifyEnum(
getShape())));
// Verifier for Tcgen05StOp. Tests for shape 16x32bx2 without an offset
// operand, and validates the length of the stored value (element count when
// it is a vector) against the shape. Failures are recorded by assigning to
// `result`.
5346LogicalResult Tcgen05StOp::verify() {
5348 if (
getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
5351 auto valTy = getVal().getType();
5352 unsigned valLen = isa<VectorType>(valTy)
5353 ? llvm::cast<VectorType>(valTy).getNumElements()
5356 result =
emitError(llvm::formatv(
"invalid input length {0} for shape "
5357 "{1} in tcgen05.st Op",
5358 valLen, stringifyEnum(
getShape())));
// When the op carries a "range" attribute of type LLVM::ConstantRangeAttr,
// report it via setResultRanges; the attribute's lower/upper bounds are
// passed twice in the initializer list.
5368 if (
auto rangeAttr = op->
getAttrOfType<LLVM::ConstantRangeAttr>(
"range")) {
5369 setResultRanges(
result, {rangeAttr.getLower(), rangeAttr.getUpper(),
5370 rangeAttr.getLower(), rangeAttr.getUpper()});
5380 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
// Validates a constant-range attribute: a range whose lower and upper bounds
// are equal is only accepted when that value is the minimum or maximum value
// for the bit width; otherwise the diagnostic below is produced.
5384 const llvm::APInt &lower = rangeAttr->getLower();
5385 const llvm::APInt &upper = rangeAttr->getUpper();
5388 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
// The min/max values are computed only to be printed in the diagnostic.
5389 unsigned bitWidth = lower.getBitWidth();
5390 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
5391 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
5393 "invalid range attribute: Lower == Upper, but they aren't min (")
5394 << llvm::toString(minVal, 10,
false) <<
") or max ("
5395 << llvm::toString(maxVal, 10,
false)
5396 <<
") value! This is an invalid constant range.";
5403 llvm::IRBuilderBase &builder) {
// Returns `arg` bitcast to the 32-bit integer type of the builder's context.
5404 return builder.CreateBitCast(arg,
5405 llvm::Type::getInt32Ty(builder.getContext()));
// Lowering for NVVM::DotAccumulate4WayOp: picks the nvvm_idp4a_* intrinsic
// from the signedness of the A and B operand types.
5410 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
5417 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
5418 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
// Table index: bit 1 = A is signed, bit 0 = B is signed.
5419 unsigned type = (isASigned << 1) | isBSigned;
5420 const llvm::Intrinsic::ID ids[] = {
5421 llvm::Intrinsic::nvvm_idp4a_u_u,
5422 llvm::Intrinsic::nvvm_idp4a_u_s,
5423 llvm::Intrinsic::nvvm_idp4a_s_u,
5424 llvm::Intrinsic::nvvm_idp4a_s_s,
5426 return {ids[type], args};
// Lowering for NVVM::DotAccumulate2WayOp: picks the nvvm_idp2a_* intrinsic
// from the signedness of the A and B operand types. The op's bHi flag is
// passed as an i1 constant argument.
5431 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
5436 args.push_back(builder.getInt1(curOp.getBHi()));
5439 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
5440 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
// Table index: bit 1 = A is signed, bit 0 = B is signed.
5441 unsigned type = (isASigned << 1) | isBSigned;
5442 const llvm::Intrinsic::ID ids[] = {
5443 llvm::Intrinsic::nvvm_idp2a_u_u,
5444 llvm::Intrinsic::nvvm_idp2a_u_s,
5445 llvm::Intrinsic::nvvm_idp2a_s_u,
5446 llvm::Intrinsic::nvvm_idp2a_s_s,
5448 return {ids[type], args};
5452 llvm::IRBuilderBase &builder) {
// Returns `addr` address-space-cast to a pointer in
// llvm::NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM.
5453 return builder.CreateAddrSpaceCast(
5454 addr, builder.getPtrTy(llvm::NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM));
// Selects the prefetch intrinsic for NVVM::PrefetchOp. The checks are
// evaluated in this order:
//  1. tensormap prefetch;
//  2. uniform prefetch at cache level L1;
//  3. L2 prefetch with an eviction priority (EvictLast / EvictNormal);
//  4. otherwise, by pointer address space (Generic / Global / Local) and
//     cache level (L1 vs. not L1).
// Unhandled eviction priorities or address spaces reach llvm_unreachable.
5458PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
5460 llvm::IRBuilderBase &builder) {
5461 using MemSpace = NVVM::NVVMMemorySpace;
5462 using CacheLevel = NVVM::PrefetchCacheLevel;
5464 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
5465 std::optional<NVVM::CacheEvictionPriority> evictPriority =
5466 op.getEvictPriority();
5467 unsigned addressSpace =
5468 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
5476 if (op.getTensormap())
5477 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
// Every path below dereferences `cacheLevel`, so it must be set here.
5479 assert(cacheLevel &&
"expected cache level for non-tensormap prefetch");
5481 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
5482 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
5484 if (evictPriority && *cacheLevel == CacheLevel::L2) {
5485 switch (*evictPriority) {
5486 case NVVM::CacheEvictionPriority::EvictLast:
5487 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
5488 case NVVM::CacheEvictionPriority::EvictNormal:
5489 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
5491 llvm_unreachable(
"Invalid cache eviction priority");
5495 switch (
static_cast<MemSpace
>(addressSpace)) {
5496 case MemSpace::Generic:
5497 return *cacheLevel == CacheLevel::L1
5499 :
NVVM::
IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
5500 case MemSpace::Global:
5501 return *cacheLevel == CacheLevel::L1
5503 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
5505 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
5506 case MemSpace::Local:
5507 return *cacheLevel == CacheLevel::L1
5509 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
5511 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
5513 llvm_unreachable(
"Invalid pointer address space");
// Walks the op's values in a fixed order: read-write arguments first, then
// results, then read-only arguments.
// NOTE(review): what is recorded per value is not visible in this excerpt.
5517bool NVVM::InlinePtxOp::getAsmValues(
5521 for (
auto arg : getReadWriteArgs())
5523 for (
auto arg : getResults())
5525 for (
auto arg : getReadOnlyArgs())
// Builds the intrinsic call for ClusterLaunchControlTryCancelOp. Arguments
// are the translated shared-memory address followed by the mbarrier; the
// multicast intrinsic variant is used when the op's multicast flag is set.
5532NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
5534 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
5536 args.push_back(mt.
lookupValue(curOp.getSmemAddress()));
5537 args.push_back(mt.
lookupValue(curOp.getMbarrier()));
5539 llvm::Intrinsic::ID intrinsicID =
5540 curOp.getMulticast()
5542 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
5543 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
5545 return {intrinsicID, args};
// Builds the intrinsic call for ClusterLaunchControlQueryCancelOp. The only
// argument is the translated try-cancel response; the intrinsic is chosen by
// the op's query type (is-canceled, or first CTA id along x / y / z).
5548NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
5550 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
5552 args.push_back(mt.
lookupValue(curOp.getTryCancelResponse()));
5554 llvm::Intrinsic::ID intrinsicID;
5556 switch (curOp.getQueryType()) {
5557 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
5559 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
5561 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
5562 intrinsicID = llvm::Intrinsic::
5563 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
5565 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
5566 intrinsicID = llvm::Intrinsic::
5567 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
5569 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
5570 intrinsicID = llvm::Intrinsic::
5571 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
5574 return {intrinsicID, args};
5579 llvm::IRBuilderBase &builder) {
// Lowering for NVVM::PermuteOp: the intrinsic is looked up in `IDs` using the
// numeric value of the op's permute mode, and the translated selector is
// appended to the arguments.
5580 auto thisOp = cast<NVVM::PermuteOp>(op);
5581 NVVM::PermuteMode mode = thisOp.getMode();
// NOTE(review): assumes the PermuteMode enumerators are numbered in the same
// order as this table; confirm against the enum definition.
5583 static constexpr llvm::Intrinsic::ID IDs[] = {
5584 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
5585 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
5586 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
5587 llvm::Intrinsic::nvvm_prmt_rc16};
5589 unsigned modeIndex =
static_cast<unsigned>(mode);
5597 args.push_back(mt.
lookupValue(thisOp.getSelector()));
5599 return {IDs[modeIndex], args};
// Lowering for NVVM::TensormapReplaceOp. Optional arguments are appended in
// this order when present: the `ord` value as an i32 constant, the translated
// new-value operand, and the new-value attribute converted to an i32
// constant. The intrinsic is looked up in `IDs` using the numeric value of
// the op's field.
5604 auto thisOp = cast<NVVM::TensormapReplaceOp>(op);
5608 if (thisOp.getOrd())
5609 args.push_back(builder.getInt32(thisOp.getOrd().value()));
5610 if (thisOp.getNewValue())
5611 args.push_back(mt.
lookupValue(thisOp.getNewValue()));
5612 if (
auto attr = thisOp.getNewValueAttr()) {
// Each supported attribute kind is reduced to its underlying value as an
// unsigned integer; any other attribute kind reaches llvm_unreachable.
5615 .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
5616 TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
5617 TensormapFillModeAttr>([](
auto attr) {
5618 return static_cast<unsigned>(attr.getValue());
5620 .Default([](
auto attr) {
5621 llvm_unreachable(
"Invalid attribute type");
5624 args.push_back(builder.getInt32(val));
// NOTE(review): assumes the field enumerators are numbered in the same order
// as this table; confirm against the enum definition.
5627 static constexpr llvm::Intrinsic::ID IDs[] = {
5628 llvm::Intrinsic::nvvm_tensormap_replace_global_address,
5629 llvm::Intrinsic::nvvm_tensormap_replace_rank,
5630 llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
5631 llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
5632 llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
5633 llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
5634 llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
5635 llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
5636 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
5637 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
5638 llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
5641 unsigned fieldIndex =
static_cast<unsigned>(thisOp.getField());
5643 return {IDs[fieldIndex], args};
5652 llvm::IRBuilderBase &builder) {
// Lowering for NVVM::Tcgen05MMAOp. The intrinsic comes from a nested table
// indexed as
//   [hasDisableOutputLane][hasScaleInputD][isATensor][ctaGroup - 1][aShift].
// Shared-memory A entries have no a-shift variant (notIntrinsic), which the
// assert below guards against. Arguments appended after the lookup: the
// scale-input-d value, the disable-output-lane value when present, the MMA
// kind as i32, the cta group as i32 only when there is no disable-output-lane
// operand, and the collector op as i32.
5654 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
5657 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
// Matrix A counts as "tensor" when its translated value has an LLVM pointer
// type.
5660 const bool isATensor = isa<llvm::PointerType>(
A->getType());
5663 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5664 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5665 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5667 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5668 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5669 using IsATensorArray = std::array<CtaGroupArray, 2>;
5670 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5671 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5674 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
5680 {llvm::Intrinsic::nvvm_tcgen05_mma_shared,
notIntrinsic},
5682 {llvm::Intrinsic::nvvm_tcgen05_mma_shared,
notIntrinsic}}},
5686 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5687 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5691 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5692 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5698 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d,
notIntrinsic},
5700 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d,
notIntrinsic}}},
5704 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5705 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5709 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5710 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5716 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
5719 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
5724 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
5726 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
5731 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
5733 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
5739 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
5743 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
5748 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
5750 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
5754 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
5756 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
// Optional operands are detected by whether their translated value is null.
5759 llvm::Value *ScaleInputD = mt.
lookupValue(thisOp.getScaleInputD());
5760 bool hasScaleInputD = ScaleInputD !=
nullptr;
5762 llvm::Value *DisableOutputLane =
5764 bool hasDisableOutputLane = DisableOutputLane !=
nullptr;
5766 const unsigned ctaGroup =
5769 llvm::Intrinsic::ID ID =
5770 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5771 [ctaGroup - 1][thisOp.getAShift()];
5773 assert(ID !=
notIntrinsic &&
"Invalid intrinsic for Tcgen05MMAOp.");
5776 args.push_back(ScaleInputD);
5778 if (hasDisableOutputLane)
5779 args.push_back(DisableOutputLane);
5781 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
// The disable-output-lane intrinsics encode the cta group in their name
// (_cg1/_cg2), so it is passed as an operand only for the other variants.
5783 if (!hasDisableOutputLane)
5784 args.push_back(builder.getInt32(ctaGroup));
5787 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5794 NVVM::CTAGroupKind ctaGroup,
bool hasAShift,
5795 NVVM::Tcgen05MMACollectorOp collectorOp,
Location loc) {
// Shared verification rules for tcgen05 MMA ops:
//  - a disable-output-lane vector must have 4 elements for CTA_1 and 8 for
//    CTA_2;
//  - a-shift requires matrix A to be in tensor memory;
//  - a-shift cannot be combined with the FILL or USE collector operations.
5797 if (disableOutputLane) {
5798 mlir::VectorType disableOutputLaneType =
5799 cast<mlir::VectorType>(disableOutputLane.
getType());
5800 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
5801 disableOutputLaneType.getNumElements() != 4) ||
5802 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
5803 disableOutputLaneType.getNumElements() != 8))
5804 return emitError(loc) <<
"Disable Output Lane of length "
5805 << disableOutputLaneType.getNumElements()
5806 <<
" is incompatible with CtaGroupAttr";
5809 if (hasAShift && !isATensor)
5811 loc,
"A-shift can be applied only when matrix A is in tensor memory");
5813 if (hasAShift ==
true && (collectorOp == Tcgen05MMACollectorOp::FILL ||
5814 collectorOp == Tcgen05MMACollectorOp::USE))
5816 loc,
"Cannot use collector buffer operation fill or use with ashift");
5821LogicalResult Tcgen05MMAOp::verify() {
5823 getDisableOutputLane(), getCtaGroup(), getAShift(),
5824 getCollectorOp(), getLoc());
5834 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
5837 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5840 bool isATensor = isa<llvm::PointerType>(
A->getType());
5843 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5844 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5845 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5846 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
5848 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5849 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5850 using IsATensorArray = std::array<CtaGroupArray, 2>;
5851 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5852 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5855 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
5861 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared,
notIntrinsic},
5863 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared,
notIntrinsic}}},
5867 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5868 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5872 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5873 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5879 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5882 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5887 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5888 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5892 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5893 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5900 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5904 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5909 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5911 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5916 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5918 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5924 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5928 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5933 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5935 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5939 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5941 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5944 llvm::Value *ScaleInputD = mt.
lookupValue(thisOp.getScaleInputD());
5945 bool hasScaleInputD = ScaleInputD !=
nullptr;
5947 llvm::Value *DisableOutputLane =
5949 bool hasDisableOutputLane = DisableOutputLane !=
nullptr;
5954 llvm::Intrinsic::ID ID =
5955 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5956 [ctaGroup - 1][thisOp.getAShift()];
5958 assert(ID !=
notIntrinsic &&
"Invalid intrinsic for Tcgen05MMASparseOp.");
5961 args.push_back(ScaleInputD);
5963 if (hasDisableOutputLane)
5964 args.push_back(DisableOutputLane);
5966 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5968 if (!hasDisableOutputLane)
5969 args.push_back(builder.getInt32(ctaGroup));
5972 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5977LogicalResult Tcgen05MMASparseOp::verify() {
5979 getDisableOutputLane(), getCtaGroup(), getAShift(),
5980 getCollectorOp(), getLoc());
5990 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5993 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5996 bool isATensor = isa<llvm::PointerType>(
A->getType());
5999 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
6000 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
6001 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
6002 args.push_back(mt.
lookupValue(thisOp.getScaleA()));
6003 args.push_back(mt.
lookupValue(thisOp.getScaleB()));
6004 args.push_back(builder.getInt32(
6007 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
6009 auto kind = thisOp.getKind();
6010 auto blockScale = thisOp.getBlockScale();
6011 llvm::Intrinsic::ID ID = [&]() {
6012 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
6013 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
6014 return isATensor ? llvm::Intrinsic::
6015 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
6017 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
6018 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6021 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
6023 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
6025 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
6026 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
6028 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
6029 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
6030 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6031 return isATensor ? llvm::Intrinsic::
6032 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
6034 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
6036 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
6037 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6040 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
6042 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
6044 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
6047 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
6049 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
6052 llvm_unreachable(
"Invalid tcgen05.mma.block_scale attributes");
6059 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind,
6060 NVVM::Tcgen05MMABlockScale blockScale,
Location loc) {
6061 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
6062 kind == NVVM::Tcgen05MMAKind::MXF4NVF4)
6063 return emitError(loc,
"mxf4nvf4 requires block scale attribute");
6065 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
6066 kind != NVVM::Tcgen05MMAKind::MXF4NVF4)
6068 llvm::formatv(
"{} kind does not support block16 attribute",
6069 stringifyEnum(kind)));
6074LogicalResult Tcgen05MMABlockScaleOp::verify() {
6076 getBlockScale(), getLoc());
6086 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
6089 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
6092 bool isATensor = isa<llvm::PointerType>(
A->getType());
6095 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
6096 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
6097 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
6098 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
6099 args.push_back(mt.
lookupValue(thisOp.getScaleA()));
6100 args.push_back(mt.
lookupValue(thisOp.getScaleB()));
6101 args.push_back(builder.getInt32(
6104 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
6106 auto kind = thisOp.getKind();
6107 auto blockScale = thisOp.getBlockScale();
6108 llvm::Intrinsic::ID ID = [&]() {
6109 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
6110 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
6111 return isATensor ? llvm::Intrinsic::
6112 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
6114 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
6115 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6118 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
6120 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
6122 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
6123 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
6124 return isATensor ? llvm::Intrinsic::
6125 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
6127 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
6128 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6131 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
6133 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
6135 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
6136 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
6139 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
6141 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
6143 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
6146 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
6148 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
6151 llvm_unreachable(
"Invalid tcgen05.mma.sp.block_scale attributes");
6157LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
6159 getBlockScale(), getLoc());
6169 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
6172 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
6175 bool isATensor = isa<llvm::PointerType>(
A->getType());
6178 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
6179 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
6180 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
6182 mlir::Value ZeroColMask = thisOp.getZeroColMask();
6186 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
6187 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
6189 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
6190 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
6192 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
6194 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorBBuffer())));
6196 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
6208 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
6211 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
6214 bool isATensor = isa<llvm::PointerType>(
A->getType());
6217 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
6218 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
6219 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
6220 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
6222 mlir::Value ZeroColMask = thisOp.getZeroColMask();
6227 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
6228 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
6230 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
6231 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
6233 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
6235 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorBBuffer())));
6237 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
// Expands to the LLVM intrinsic ID
// llvm::Intrinsic::nvvm_tcgen05_ld_red_<SHAPE>_<NUM>_<TYPE>, formed by
// token-pasting the three macro arguments into the identifier.
6246#define TCGEN05LDRED(SHAPE, NUM, TYPE) \
6247 llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE
6251 auto thisOp = cast<NVVM::Tcgen05LdRedOp>(op);
6254 mlir::VectorType VecResTy =
6255 cast<mlir::VectorType>(thisOp.getData().getType());
6256 unsigned Num = VecResTy.getNumElements();
6257 bool IsFloat = thisOp.getRedVal().getType().isF32();
6259 llvm::Intrinsic::ID Shape32x32b[][2] = {
6270 llvm::Intrinsic::ID Shape16x32bx2[][2] = {
6281 NVVM::Tcgen05LdStShape
shape = thisOp.getShape();
6282 unsigned ID = [&]() {
6285 unsigned idx = std::log2(Num);
6287 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
6288 return Shape32x32b[idx][IsFloat];
6289 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
6290 return Shape16x32bx2[idx][IsFloat];
6292 llvm_unreachable(
"unhandled tcgen05.ld lowering");
6298 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2)
6299 args.push_back(mt.
lookupValue(thisOp.getOffset()));
6302 builder.getInt32(thisOp.getOp() == NVVM::ReductionKind::MIN ? 0 : 1));
6305 args.push_back(builder.getInt1(
static_cast<unsigned>(thisOp.getAbs())));
6306 args.push_back(builder.getInt1(
static_cast<unsigned>(thisOp.getNan())));
// Verifier for Tcgen05LdRedOp. The checks visible here reject:
//  - a reduction value whose type differs from the data vector's element type,
//  - a reduction kind other than MIN or MAX,
//  - the abs or nan flags when the reduction value is an integer type.
6311LogicalResult Tcgen05LdRedOp::verify() {
6312 VectorType data = cast<VectorType>(getData().
getType());
6313 Type redVal = getRedVal().getType();
// The scalar reduction value must have the same type as the vector elements.
6315 if (data.getElementType() != redVal)
6317 "type of reduction value and element type of vector data should match");
// Only the MIN and MAX reduction kinds are accepted.
6319 if (getOp() != NVVM::ReductionKind::MIN &&
6320 getOp() != NVVM::ReductionKind::MAX)
6321 return emitError(
"only min and max reduction kinds are supported");
// abs/nan are rejected for integer reduction values.
6323 if (redVal.
isInteger() && (getAbs() || getNan())) {
6324 return emitError(
"abs or nan is only applicable for f32 type");
6334void NVVMDialect::initialize() {
6337#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
6340#define GET_ATTRDEF_LIST
6341#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
6346 allowUnknownOperations();
6347 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
6348 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
// Verifies NVVM dialect attributes attached to an operation. The checks are
// selected by comparing the attribute's name against the dialect's known
// attribute names; each failed check emits an error on `op`.
6351LogicalResult NVVMDialect::verifyOperationAttribute(
Operation *op,
6353 StringAttr attrName = attr.
getName();
// The kernel-function attribute is only accepted on LLVM::LLVMFuncOp.
6355 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
6356 if (!isa<LLVM::LLVMFuncOp>(op)) {
6357 return op->
emitError() <<
"'" << NVVMDialect::getKernelFuncAttrName()
6358 <<
"' attribute attached to unexpected op";
// maxntid / reqntid / cluster-dim: the value must be a DenseI32ArrayAttr
// holding between one and three entries.
6363 if (attrName == NVVMDialect::getMaxntidAttrName() ||
6364 attrName == NVVMDialect::getReqntidAttrName() ||
6365 attrName == NVVMDialect::getClusterDimAttrName()) {
6366 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
6367 if (!values || values.empty() || values.size() > 3) {
6370 <<
"' attribute must be integer array with maximum 3 index";
// minctasm / maxnreg / cluster-max-blocks: the value must be an IntegerAttr.
6375 if (attrName == NVVMDialect::getMinctasmAttrName() ||
6376 attrName == NVVMDialect::getMaxnregAttrName() ||
6377 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
6378 if (!llvm::dyn_cast<IntegerAttr>(attr.
getValue())) {
6380 <<
"'" << attrName <<
"' attribute must be integer constant";
// blocks-are-clusters: the op must also carry both the reqntid and the
// cluster-dim attributes.
6384 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
6385 if (!op->
hasAttr(NVVMDialect::getReqntidAttrName()) ||
6386 !op->
hasAttr(NVVMDialect::getClusterDimAttrName())) {
6388 <<
"'" << attrName <<
"' attribute must be used along with " <<
"'"
6389 << NVVMDialect::getReqntidAttrName() <<
"' and " <<
"'"
6390 << NVVMDialect::getClusterDimAttrName() <<
"'";
// Verifies NVVM dialect attributes attached to region arguments. The logic
// visible here handles the grid-constant attribute.
6397LogicalResult NVVMDialect::verifyRegionArgAttribute(
Operation *op,
6398 unsigned regionIndex,
6401 auto funcOp = dyn_cast<FunctionOpInterface>(op);
// True when `op` itself carries the NVVM kernel-function attribute.
6405 bool isKernel = op->
hasAttr(NVVMDialect::getKernelFuncAttrName());
6406 StringAttr attrName = argAttr.
getName();
// Grid-constant checks. NOTE(review): the condition guarding the first
// diagnostic is not visible in this excerpt; its message says the attribute
// may only appear on kernel arguments -- confirm against the full source.
6407 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
6411 <<
"' attribute must be present only on kernel arguments";
// The attribute's value has to be a UnitAttr.
6413 if (!isa<UnitAttr>(argAttr.
getValue()))
6414 return op->
emitError() <<
"'" << attrName <<
"' must be a unit attribute";
// The argument at `argIndex` must also carry the LLVM byval attribute.
6415 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
6418 <<
"' attribute requires the argument to also have attribute '"
6419 << LLVM::LLVMDialect::getByValAttrName() <<
"'";
// Returns this attribute's stored value converted to `unsigned`.
6430unsigned NVVMMemorySpaceAttr::getAddressSpace()
const {
6431 return static_cast<unsigned>(getValue());
6434bool NVVMMemorySpaceAttr::isValidLoad(
6435 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
6436 const ::mlir::DataLayout *dataLayout,
6442bool NVVMMemorySpaceAttr::isValidStore(
6443 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
6444 const ::mlir::DataLayout *dataLayout,
6450bool NVVMMemorySpaceAttr::isValidAtomicOp(
6451 ptr::AtomicBinOp op,
Type type, ptr::AtomicOrdering ordering,
6452 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
6455 assert(
false &&
"unimplemented, see TODO in the source.");
6459bool NVVMMemorySpaceAttr::isValidAtomicXchg(
6460 Type type, ptr::AtomicOrdering successOrdering,
6461 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
6462 const ::mlir::DataLayout *dataLayout,
6465 assert(
false &&
"unimplemented, see TODO in the source.");
6469bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
6473 assert(
false &&
"unimplemented, see TODO in the source.");
6477bool NVVMMemorySpaceAttr::isValidPtrIntCast(
6482 assert(
false &&
"unimplemented, see TODO in the source.");
6491 int optLevel, StringRef triple, StringRef chip,
6492 StringRef features, DictionaryAttr flags,
6494 if (optLevel < 0 || optLevel > 3) {
6495 emitError() <<
"The optimization level must be a number between 0 and 3.";
6498 if (triple.empty()) {
6499 emitError() <<
"The target triple cannot be empty.";
6503 emitError() <<
"The target chip cannot be empty.";
6506 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
6507 return mlir::isa_and_nonnull<StringAttr>(attr);
6509 emitError() <<
"All the elements in the `link` array must be strings.";
// Verifies the operations of a GPU module against this NVVM target.
// Visible steps: an early exit when getVerifyTarget() is false, a check that
// `gpuModule` is a gpu::GPUModuleOp, a minimum SM version check, and a walk
// over nested operations that inspects those implementing
// NVVM::RequiresSMInterface.
// NOTE(review): several lines (early-return value, SM version computation,
// the compatibility test inside the walk) are not visible in this excerpt.
6515LogicalResult NVVMTargetAttr::verifyTarget(
Operation *gpuModule) {
6516 if (!getVerifyTarget())
6519 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
6522 "NVVM target attribute must be attached to a GPU module");
6525 const unsigned targetFullSmVersion =
6529 "Minimum NVVM target SM version is sm_20");
6533 ->
walk([&](Operation *op) {
// Only ops that declare an SM requirement are examined; failures are reported
// on the op together with the target chip name.
6534 if (
auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
6535 const NVVMCheckSMVersion requirement =
6536 reqOp.getRequiredMinSMVersion();
6538 op->
emitOpError() <<
"is not supported on " << getChip();
6550#define GET_OP_CLASSES
6551#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
6553#define GET_ATTRDEF_CLASSES
6554#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred, TypeRange actual)
For ops with optional results, allow the user to omit the result even when inference would produce on...
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
static void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyAddSubFOp(OpType op)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
static void printOperandList(OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static llvm::Value * castPtrToAddrSpace(llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static llvm::Intrinsic::ID getBarrierReductionIntrinsic(bool aligned, NVVM::BarrierReduction kind)
Maps the (aligned, kind) pair to the @llvm.nvvm.barrier.cta.red.
static void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
static LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &regs)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
#define TCGEN05LDRED(SHAPE, NUM, TYPE)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static MMATypes inferPtxTypeFromResult(OpTy op)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static LogicalResult parseMmaTypeSignature(OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
static bool isPtrInSharedClusterSpace(mlir::Value ptr)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static bool isPtrInGenericSpace(mlir::Value ptr)
static void processOperandFragments(Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > &regTypes, SmallVectorImpl< StringRef > &ignoreAttrNames)
static llvm::Intrinsic::ID getBarrierSyncIntrinsic(bool aligned, bool hasCount)
Maps the (aligned, hasCount) pair to the @llvm.nvvm.barrier.cta.sync.
static constexpr unsigned notIntrinsic
static LogicalResult inferMBarrierArriveResultTypes(MLIRContext *context, Value addr, SmallVectorImpl< Type > &inferredReturnTypes)
Only shared_cluster (ptr<7>) produces zero results; all other address spaces (including generic) retu...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class represents a diagnostic that is inflight and set to be reported.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
LogicalResult matchAndRewrite(SubFOp op, PatternRewriter &rewriter) const override
static bool isMinimumSMVersion(unsigned fullSmVersion)
static unsigned getTargetFullSmVersionFromStr(StringRef smVersionString)
bool isCompatibleWith(const unsigned &targetFullSmVersion) const
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.