31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
// Sentinel meaning "no LLVM intrinsic"; used by lowering tables/selectors in
// this file to signal that an op variant has no direct intrinsic mapping.
49static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
// NOTE(review): two pointer/address-space helpers; their signature lines are
// missing from this extraction, and leading numbers are extraction artifacts.
// First helper: presumably returns whether `ptr`'s LLVM pointer type carries
// the given NVVM memory space — TODO confirm against the full file.
56   auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(
ptr.getType());
57   return ptrTy.getAddressSpace() ==
static_cast<unsigned>(targetAS);
// Second helper: emits an `addrspacecast` of `ptr` into `targetAS` using the
// provided IR builder (used during LLVM IR translation).
74                                      NVVMMemorySpace targetAS) {
75   unsigned AS =
static_cast<unsigned>(targetAS);
76   return builder.CreateAddrSpaceCast(
77       ptr, llvm::PointerType::get(builder.getContext(), AS));
// Maps the MLIR NVVM CTAGroupKind enum onto the corresponding llvm::nvvm
// enumerator. NOTE(review): the function name / switch header lines are
// missing from this extraction.
81static llvm::nvvm::CTAGroupKind
84   case NVVM::CTAGroupKind::CTA_1:
85     return llvm::nvvm::CTAGroupKind::CG_1;
86   case NVVM::CTAGroupKind::CTA_2:
87     return llvm::nvvm::CTAGroupKind::CG_2;
// Defensive: all enumerators are handled above, so reaching here is a bug.
89   llvm_unreachable(
"unsupported cta_group value");
// Shared validation for TMA tensor-store-like ops: tensor rank must be 1..5,
// and im2col mode constrains the rank and offset count. NOTE(review): the
// function name and several interior lines are missing from this extraction.
101                                          size_t numIm2ColOffsets,
103   if (tensorDims < 1 || tensorDims > 5)
104     return emitError(loc,
"expects coordinates between 1 to 5 dimension");
112         "to use im2col mode, the tensor has to be at least 3-dimensional");
// im2col offsets, when present, must number exactly (rank - 2).
114   if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
116         loc,
"im2col offsets must be 2 less than number of coordinates");
// Verifier for the TMA shared::cta -> global store op. The predicated
// (inline-PTX) lowering path only supports Tile mode and no L2 cache hint;
// the trailing switch validates per-mode coordinate counts.
// NOTE(review): several interior lines (switch header, per-case returns) are
// missing from this extraction.
121LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
122   TMAStoreMode mode = getMode();
126   if (getPredicate()) {
127     if (mode != TMAStoreMode::TILE)
128       return emitError(
"Inline-ptx lowering supported only for Tile mode.");
129     if (getL2CacheHint())
130       return emitError(
"Inline-ptx lowering unsupported with L2 cache-hint.");
135   case TMAStoreMode::TILE:
137   case TMAStoreMode::IM2COL:
139   case TMAStoreMode::TILE_SCATTER4:
// Scatter4 presumably requires exactly 5 coordinates — the guarding
// condition is on a missing line; TODO confirm.
141       return emitError(
"Scatter4 mode expects 5 coordinates");
146LogicalResult CpAsyncOp::verify() {
147 if (getModifier() != LoadCacheModifierKind::CG &&
148 getModifier() != LoadCacheModifierKind::CA)
149 return emitError(
"Only CG and CA cache modifiers are supported.");
150 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
151 return emitError(
"expected byte size to be either 4, 8 or 16.");
152 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
153 return emitError(
"CG cache modifier is only support for 16 bytes copy.");
// Shared validation for TMA tensor-load-like ops: rank must be 1..5, then a
// per-mode check of the im2col offset count via a local lambda.
// NOTE(review): the function signature, the lambda's success return, and the
// switch header are on lines missing from this extraction.
160   if (tensorDims < 1 || tensorDims > 5)
161     return emitError(loc,
"expects coordinates between 1 to 5 dimension");
// Lambda checking one (mode, expected-offset-count) combination.
163   auto checkTMALoadParams = [&](TMALoadMode mode,
bool isIm2col,
164                                 size_t expectedIm2colOff) -> LogicalResult {
165     if (isIm2col && (tensorDims < 3))
168              <<
" mode, the tensor has to be at least 3-dimensional";
170     if (numIm2colOff != expectedIm2colOff)
171       return emitError(loc) <<
" im2col offsets expected " << expectedIm2colOff
172                             <<
" (provided " << numIm2colOff <<
")";
// Per-mode dispatch: Tile takes no offsets; IM2COL takes rank-2; the W
// variants take exactly 2; Gather4 additionally pins the rank to 5.
178   case TMALoadMode::TILE:
179     return checkTMALoadParams(mode,
false, 0);
180   case TMALoadMode::IM2COL:
181     return checkTMALoadParams(mode,
true, tensorDims - 2);
182   case TMALoadMode::IM2COL_W:
183   case TMALoadMode::IM2COL_W_128:
184     return checkTMALoadParams(mode,
true, 2);
185   case TMALoadMode::TILE_GATHER4:
186     return (tensorDims == 5)
187                ? checkTMALoadParams(mode,
false, 0)
188                :
emitError(loc,
"Gather4 mode expects 5 coordinates");
// Verifier for the TMA prefetch op: delegates to the shared TMA-load
// parameter validation. NOTE(review): the argument line between these two
// (coordinate/offset counts, presumably) is missing from this extraction.
193LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
195                              getMode(), getLoc());
// Verifier for the TMA global -> shared{::cta,::cluster} load op.
// Checks: predicate restrictions, destination address space consistent with
// the isCTAOnly flag, and cta-only restrictions (no multicast / cta_group).
// Ends by delegating to the shared TMA-load parameter validation.
// NOTE(review): several interior lines are missing from this extraction.
198LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
199   TMALoadMode mode = getMode();
200   bool isCTAOnly = getIsCTAOnly();
201   if (getPredicate()) {
203       return emitError(
"Predicate is supported only for shared::cluster mode.");
204     if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
206           "Predicate is supported only for Tile and Im2col modes.");
// The destination pointer's address space must match the addressing mode:
// shared (3) for cta-only, shared::cluster (7) otherwise.
208   NVVMMemorySpace expectedAS =
209       isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
210   unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().
getType())
212   if (AS != expectedAS)
215                          ?
"Shared::cta destination requires address-space 3."
216                          :
"Shared::cluster destination requires address-space 7.");
// cta-only transfers cannot use multicast or cta_group features.
219     if (getMulticastMask())
220       return emitError(
"Multicast is not supported with shared::cta mode.");
222       return emitError(
"CTAGroup is not supported with shared::cta mode.");
227                              getMode(), getLoc());
// Verifier for the TMA tensor reduce op: dispatches on the store mode;
// scatter modes are rejected. NOTE(review): the switch header and the
// Tile/Im2col case bodies are on lines missing from this extraction.
230LogicalResult CpAsyncBulkTensorReduceOp::verify() {
231   TMAStoreMode mode = getMode();
234   case TMAStoreMode::TILE:
236   case TMAStoreMode::IM2COL:
238   case TMAStoreMode::TILE_SCATTER4:
239     return emitError(
"Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
// Verifier for the bulk (non-tensor) global -> shared cluster copy:
// multicast is incompatible with a shared::cta destination.
// NOTE(review): the line defining `isSharedCTA` (presumably an address-space
// check on the destination pointer) is missing from this extraction.
244LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
246   if (isSharedCTA && getMulticastMask())
247     return emitError(
"Multicast is not supported with shared::cta mode.");
// Common mbarrier verification helper: scope must be CTA or Cluster, and an
// mbarrier living in shared::cluster space cannot produce a return value.
// NOTE(review): the function name and the shared-cluster check's definition
// line are missing from this extraction.
253                                         NVVM::MemScopeKind scope,
254                                         Value retVal =
nullptr) {
255   if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
256     return op->
emitError(
"mbarrier scope must be either CTA or Cluster");
259   bool hasRetValue =
static_cast<bool>(retVal);
260   if (isSharedCluster && hasRetValue)
262         "mbarrier in shared_cluster space cannot return any value");
// Verifiers for the mbarrier op family. Most are thin wrappers (bodies on
// missing lines, presumably delegating to the common helper above); the
// ArriveExpectTx verifier additionally restricts the predicated form:
// CTA scope only, no shared_cluster barrier, no return value, no relaxed
// semantics. NOTE(review): many interior lines are missing here.
267LogicalResult MBarrierArriveOp::verify() {
272LogicalResult MBarrierArriveDropOp::verify() {
277LogicalResult MBarrierArriveExpectTxOp::verify() {
281   if (getPredicate()) {
282     if (getScope() != NVVM::MemScopeKind::CTA)
283       return emitError(
"mbarrier scope must be CTA when using predicate");
286       return emitError(
"mbarrier in shared_cluster space is not supported when "
290       return emitError(
"return-value is not supported when using predicate");
292     if (getRelaxed() ==
true)
293       return emitError(
"mbarrier with relaxed semantics is not supported when "
300LogicalResult MBarrierArriveDropExpectTxOp::verify() {
305LogicalResult MBarrierExpectTxOp::verify() {
309LogicalResult MBarrierCompleteTxOp::verify() {
313LogicalResult MBarrierTestWaitOp::verify() {
317LogicalResult MBarrierTryWaitOp::verify() {
// Verifier for f32 -> tf32 conversion: relu is incompatible with the RNA
// rounding mode, and only {rn, rz, rna} rounding modes are accepted.
// NOTE(review): the conditions guarding both diagnostics are on lines
// missing from this extraction.
321LogicalResult ConvertFloatToTF32Op::verify() {
322   using RndMode = NVVM::FPRoundingMode;
326     return emitError(
"Relu not supported with rna rounding mode.");
333         "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
// Verifier for f32x2 -> f6x2: the destination element type must be one of
// the two supported f6 formats (E2M3FN / E3M2FN).
338LogicalResult ConvertF32x2ToF6x2Op::verify() {
341   if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
343            << mlir::Float6E2M3FNType::get(ctx) <<
" and "
344            << mlir::Float6E3M2FNType::get(ctx)
345            <<
" types are supported for conversions from f32x2 to f6x2.";
// Verifier for f32x2 -> f8x2, dispatched on the destination f8 format via
// TypeSwitch: E4M3FN/E5M2 require RN rounding (and, per the second
// diagnostic, SATFINITE saturation); E8M0FNU requires RZ or RP rounding and
// forbids relu; any other type is rejected by the fallthrough diagnostic.
// NOTE(review): several structural lines (TypeSwitch header, case closers,
// default case) are missing from this extraction.
350LogicalResult ConvertF32x2ToF8x2Op::verify() {
351   using RndMode = NVVM::FPRoundingMode;
352   using SatMode = NVVM::SaturationMode;
354   bool isRoundingModeRN = getRnd() == RndMode::RN;
355   bool isRoundingModeRZ = getRnd() == RndMode::RZ;
356   bool isRoundingModeRP = getRnd() == RndMode::RP;
357   bool isSatFinite = getSat() == SatMode::SATFINITE;
359   bool hasRelu = getRelu();
364       .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
366             if (!isRoundingModeRN) {
367               return emitOpError(
"Only RN rounding mode is supported for "
368                                  "conversions from f32x2 to ")
369                      << mlir::Float8E4M3FNType::get(ctx) <<
" and "
370                      << mlir::Float8E5M2Type::get(ctx) <<
" types";
373               return emitOpError(
"Only SATFINITE saturation mode is supported "
376                      << mlir::Float8E4M3FNType::get(ctx) <<
" and "
377                      << mlir::Float8E5M2Type::get(ctx) <<
" types";
381       .Case<mlir::Float8E8M0FNUType>([&](
mlir::Type) -> LogicalResult {
382         if (!(isRoundingModeRZ || isRoundingModeRP)) {
383           return emitOpError(
"Only RZ and RP rounding modes are supported for "
384                              "conversions from f32x2 to ")
385                  << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
// E8M0 has no sign bit, so relu makes no sense for this target type.
388           return emitOpError(
"relu not supported for conversions to ")
389                  << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
395                << mlir::Float8E4M3FNType::get(ctx) <<
", "
396                << mlir::Float8E5M2Type::get(ctx) <<
", and "
397                << mlir::Float8E8M0FNUType::get(ctx)
399                   "supported for conversions from f32x2 to f8x2";
// Verifier for f16x2 -> f8x2: destination must be E4M3FN or E5M2.
403LogicalResult ConvertF16x2ToF8x2Op::verify() {
406   if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
408            << mlir::Float8E4M3FNType::get(ctx) <<
" and "
409            << mlir::Float8E5M2Type::get(ctx)
410            <<
" types are supported for conversions from f16x2 to f8x2.";
// Verifier for bf16x2 -> f8x2: destination must be E8M0FNU and rounding must
// be RZ or RP. NOTE(review): the error-stream line naming the expected type
// is on a line missing from this extraction.
415LogicalResult ConvertBF16x2ToF8x2Op::verify() {
416   using RndMode = NVVM::FPRoundingMode;
418   if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
420            <<
" type is supported for conversions from "
424   if (rnd != RndMode::RZ && rnd != RndMode::RP)
425     return emitOpError(
"Only RZ and RP rounding modes are supported for "
426                        "conversions from bf16x2 to f8x2.");
// Verifier for f32x2 -> f4x2: destination must be the E2M1FN f4 format.
431LogicalResult ConvertF32x2ToF4x2Op::verify() {
434   if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
436            << mlir::Float4E2M1FNType::get(ctx)
437            <<
" type is supported for conversions from f32x2 to f4x2.";
// Verifier for f8x2 -> f16x2: source must be E4M3FN or E5M2.
442LogicalResult ConvertF8x2ToF16x2Op::verify() {
445   if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
447            << mlir::Float8E4M3FNType::get(ctx) <<
" and "
448            << mlir::Float8E5M2Type::get(ctx)
449            <<
" types are supported for conversions from f8x2 to f16x2.";
// Verifier for f8x2 -> bf16x2: source must be the E8M0FNU f8 format.
454LogicalResult ConvertF8x2ToBF16x2Op::verify() {
456   if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
458            << mlir::Float8E8M0FNUType::get(ctx)
459            <<
" type is supported for conversions from f8x2 to bf16x2.";
// Verifier for f6x2 -> f16x2: source must be E2M3FN or E3M2FN.
464LogicalResult ConvertF6x2ToF16x2Op::verify() {
467   if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
469            << mlir::Float6E2M3FNType::get(ctx) <<
" and "
470            << mlir::Float6E3M2FNType::get(ctx)
471            <<
" types are supported for conversions from f6x2 to f16x2.";
// Verifier for f4x2 -> f16x2: source must be the E2M1FN f4 format.
476LogicalResult ConvertF4x2ToF16x2Op::verify() {
479   if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
481            << mlir::Float4E2M1FNType::get(ctx)
482            <<
" type is supported for conversions from f4x2 to f16x2.";
// Verifier for the PRMT-style permute op: certain modes require the optional
// `hi` operand while the others must not provide it. NOTE(review): the
// mode-classification lines between the two diagnostics are missing from
// this extraction.
487LogicalResult PermuteOp::verify() {
488   using Mode = NVVM::PermuteMode;
489   bool hasHi =
static_cast<bool>(getHi());
496     return emitError(
"mode '") << getMode() <<
"' requires 'hi' operand.";
504            << getMode() <<
"' does not accept 'hi' operand.";
// Common checks for the f32x2 narrowing conversions (to f16x2/bf16x2):
// rounding mode must be RN, RZ or RS; RS requires the random_bits operand,
// which in turn is only legal with RS. NOTE(review): the function signature
// and some interior lines are missing from this extraction.
519   static constexpr FPRoundingMode validRndModes[] = {
520       FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
522   if (!llvm::is_contained(validRndModes, rnd)) {
524                "Only RN, RZ, and RS rounding modes are supported for "
525                "conversions from f32x2 to ")
// Stochastic rounding (RS) consumes entropy, hence the operand coupling.
529   if (rnd == FPRoundingMode::RS) {
530     if (!hasRandomBits) {
531       return op->
emitOpError(
"random_bits is required for RS rounding mode.");
536           "random_bits not supported for RN and RZ rounding modes.");
// Verifier for f32x2 -> f16x2: delegates to the shared rounding-mode /
// random-bits validation helper above.
543LogicalResult ConvertF32x2ToF16x2Op::verify() {
545                                    getRandomBits() ?
true :
false, *
this);
// Verifier for f32x2 -> bf16x2: same shared validation as the f16x2 case.
548LogicalResult ConvertF32x2ToBF16x2Op::verify() {
550                                    getRandomBits() ?
true :
false, *
this);
// Verifier for f32x4 -> f8x4: destination must be E4M3FN or E5M2.
553LogicalResult ConvertF32x4ToF8x4Op::verify() {
556   if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
558            << mlir::Float8E4M3FNType::get(ctx) <<
" and "
559            << mlir::Float8E5M2Type::get(ctx)
560            <<
" types are supported for conversions from f32x4 to f8x4.";
// Verifier for f32x4 -> f6x4: destination must be E2M3FN or E3M2FN.
565LogicalResult ConvertF32x4ToF6x4Op::verify() {
568   if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
570            << mlir::Float6E2M3FNType::get(ctx) <<
" and "
571            << mlir::Float6E3M2FNType::get(ctx)
572            <<
" types are supported for conversions from f32x4 to f6x4.";
// Verifier for f32x4 -> f4x4: destination must be the E2M1FN f4 format.
577LogicalResult ConvertF32x4ToF4x4Op::verify() {
580   if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
581     return emitOpError(
"Only ") << mlir::Float4E2M1FNType::get(ctx)
582            <<
" type is supported for conversions from "
588LogicalResult BulkStoreOp::verify() {
589 if (getInitVal() != 0)
590 return emitOpError(
"only 0 is supported for initVal, got ") << getInitVal();
// Verifier for the pm.event op: exactly one of `id` / `mask` must be set,
// and a plain `id` must be in [0, 15] (PTX pmevent supports 16 counters).
// NOTE(review): the closing lines of each if-block are missing here.
594LogicalResult PMEventOp::verify() {
595   auto eventId = getEventId();
596   auto maskedEventId = getMaskedEventId();
597   if (!maskedEventId && !eventId) {
598     return emitOpError() <<
"either `id` or `mask` must be set";
601   if (maskedEventId && eventId) {
602     return emitOpError() <<
"`id` and `mask` cannot be set at the same time";
606   if (eventId < 0 || eventId > 15) {
607     return emitOpError() <<
"`id` must be between 0 and 15";
611   return llvm::success();
// Infers the PTX MMA element type for an operand from its MLIR type:
// f64 -> f64, f16 (scalar or <2 x f16>) -> f16, f32 -> f32 when used as the
// accumulator but tf32 when used as a multiplicand, integers -> s32 for
// accumulators, and struct types recurse on their first member. Returns
// nullopt when no mapping applies. NOTE(review): several interior lines
// (half2Type declaration head, integer non-accumulator path, closers) are
// missing from this extraction.
617std::optional<mlir::NVVM::MMATypes>
618MmaOp::inferOperandMMAType(
Type operandElType,
bool isAccumulator) {
620       VectorType::get(2, Float16Type::get(operandElType.
getContext()));
621   if (operandElType.
isF64())
622     return NVVM::MMATypes::f64;
623   if (operandElType.
isF16() || operandElType == half2Type)
624     return NVVM::MMATypes::f16;
// f32 means a true f32 accumulator, but tf32 multiplicand fragments.
625   if (operandElType.
isF32() && isAccumulator)
626     return NVVM::MMATypes::f32;
627   if (operandElType.
isF32() && !isAccumulator)
628     return NVVM::MMATypes::tf32;
629   if (llvm::isa<IntegerType>(operandElType)) {
631       return NVVM::MMATypes::s32;
// Struct operands (e.g. accumulator fragments) defer to their first member.
635   if (
auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
636     if (structType.getBody().empty())
638     return inferOperandMMAType(structType.getBody()[0], isAccumulator);
// Predicates classifying integer PTX MMA element types (4-bit, 8-bit, and
// any-integer). NOTE(review): the signature lines are missing from this
// extraction.
645   return (type == MMATypes::u4 || type == MMATypes::s4);
649   return (type == MMATypes::u8 || type == MMATypes::s8);
654          type == MMATypes::s32;
// Returns the PTX element type of the C (accumulator) operand segment
// (ODS operand group 2); inference is expected to always succeed here.
657MMATypes MmaOp::accumPtxType() {
658   std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
659       getODSOperands(2).getTypes().front(),
true);
660   assert(val.has_value() &&
"accumulator PTX type should always be inferrable");
// Returns the PTX element type of the op result (treated as an accumulator
// for inference purposes); inference is expected to always succeed here.
664MMATypes MmaOp::resultPtxType() {
665   std::optional<mlir::NVVM::MMATypes> val =
666       inferOperandMMAType(getResult().
getType(),
true);
667   assert(val.has_value() &&
"result PTX type should always be inferrable");
// Custom printer for nvvm.mma.sync: groups the variadic operands into the
// A/B/C fragments, infers which ptx-type attributes are implied by operand
// types (and elides them), then prints each fragment and the trailing
// functional type. NOTE(review): the surrounding function signature and a
// number of interior lines are missing from this extraction.
673   struct MMAOperandFragment {
674     StringRef operandName;
675     StringRef ptxTypeAttr;
676     SmallVector<Value, 4> regs;
677     explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
678         : operandName(name), ptxTypeAttr(ptxTypeName) {}
// A and B carry explicit ptx-type attribute names; C's type is implied.
681   std::array<MMAOperandFragment, 3> frags{
682       MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
683       MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
684       MMAOperandFragment(
"C",
"")};
686       mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
// Collect each ODS operand segment into its fragment; remember the first
// register type per segment for attribute-elision inference.
688   for (
unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
689     auto &frag = frags[fragIdx];
690     auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
691     for (
auto operandIdx = varOperandSpec.first;
692          operandIdx < varOperandSpec.first + varOperandSpec.second;
694       frag.regs.push_back(this->getOperand(operandIdx));
695       if (operandIdx == 0) {
696         regTypes.push_back(this->getOperand(operandIdx).
getType());
// Fragments at index >= 2 are accumulators for inference purposes.
699     std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
700         regTypes.back(), fragIdx >= 2);
702       ignoreAttrNames.push_back(frag.ptxTypeAttr);
705   auto printMmaOperand = [&](
const MMAOperandFragment &frag) ->
void {
706     p <<
" " << frag.operandName;
712   for (
const auto &frag : frags) {
713     printMmaOperand(frag);
722                           frags[1].regs[0].getType(),
723                           frags[2].regs[0].getType()},
// Builder for nvvm.mma.sync: adds the A/B/C operand groups, sets the
// multiplicand ptx-type attributes (explicit or inferred from the first
// operand of each group), the layouts (explicit or row/col defaults), the
// optional intOverflow / b1Op attributes, the result type, and the operand
// segment sizes. NOTE(review): the builder signature head and several
// closing lines are missing from this extraction.
732                   std::optional<MMAIntOverflow> intOverflow,
733                   std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
734                   std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
736   assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
741   result.addOperands(operandA);
742   result.addOperands(operandB);
743   result.addOperands(operandC);
// Explicit ptx types win; otherwise infer from the first A/B operand type.
745   if (multiplicandPtxTypes) {
746     result.addAttribute(
"multiplicandAPtxType",
747                         MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
748     result.addAttribute(
"multiplicandBPtxType",
749                         MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
751     if (
auto res = inferOperandMMAType(operandA[0].
getType(),
false))
752       result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
753     if (
auto res = inferOperandMMAType(operandB[0].
getType(),
false))
754       result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
// Explicit layouts win; default is A row-major, B column-major.
757   if (multiplicandLayouts) {
758     result.addAttribute(
"layoutA",
759                         MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
760     result.addAttribute(
"layoutB",
761                         MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
763     result.addAttribute(
"layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
764     result.addAttribute(
"layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
767   if (intOverflow.has_value())
768     result.addAttribute(
"intOverflowBehavior",
769                         MMAIntOverflowAttr::get(ctx, *intOverflow));
770   if (b1Op.has_value())
771     result.addAttribute(
"b1Op", MMAB1OpAttr::get(ctx, *b1Op));
773   result.addTypes(resultType);
775       MmaOp::getOperandSegmentSizeAttr(),
777                                 static_cast<int32_t>(operandB.size()),
778                                 static_cast<int32_t>(operandC.size())}));
// Custom parser for nvvm.mma.sync: reads the A/B/C fragments, resolves each
// fragment's register types from the trailing functional type, infers the
// multiplicand ptx-type attributes when not explicitly provided (erroring if
// neither is possible), and finally records result type, attributes, and
// operand segment sizes. NOTE(review): the parse signature and many interior
// lines are missing from this extraction.
786   struct MMAOperandFragment {
787     std::optional<MMATypes> elemtype;
788     SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
789     SmallVector<Type> regTypes;
793   std::array<MMAOperandFragment, 4> frags;
799                             MMAOperandFragment &frag) -> LogicalResult {
// The functional type must supply exactly one type per operand segment.
829   if (operandTypes.size() != 3)
832                             "expected one type for each operand segment but got " +
833                                 Twine(operandTypes.size()) +
" types");
834   for (
const auto &iter : llvm::enumerate(operandTypes)) {
835     auto &frag = frags[iter.index()];
836     frag.regTypes.resize(frag.regs.size(), iter.value());
840     frag.elemtype = inferOperandMMAType(frag.regTypes[0],
// The result fragment (index 3) is treated as an accumulator.
847   frags[3].elemtype = inferOperandMMAType(resultType,
true);
849   std::array<StringRef, 2> names{
"multiplicandAPtxType",
850                                  "multiplicandBPtxType"};
851   for (
unsigned idx = 0; idx < names.size(); idx++) {
852     const auto &frag = frags[idx];
853     std::optional<NamedAttribute> attr = namedAttributes.
getNamed(names[idx]);
854     if (!frag.elemtype.has_value() && !attr.has_value()) {
857                    "attribute " + names[idx] +
858                        " is not provided explicitly and cannot be inferred");
860     if (!attr.has_value())
862           names[idx], MMATypesAttr::get(parser.
getContext(), *frag.elemtype));
865   result.addTypes(resultType);
866   if (!namedAttributes.
empty())
867     result.addAttributes(namedAttributes);
868   result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
870                           static_cast<int32_t>(frags[0].regs.size()),
871                           static_cast<int32_t>(frags[1].regs.size()),
872                           static_cast<int32_t>(frags[2].regs.size()),
// Verifier for nvvm.mma.sync: builds tables of allowed operand/result
// fragment types and shapes for each (m, multiplicand-type) family, then
// checks the actual operand segments, the result type, the attribute
// requirements (b1Op for b1, intOverflowBehavior for integer types), and the
// row/col layout restriction (all shapes except m8n8k4-f16 require A=row,
// B=col). NOTE(review): many interior lines of this large verifier are
// missing from this extraction.
877LogicalResult MmaOp::verify() {
879   auto f16Ty = Float16Type::get(context);
880   auto i32Ty = IntegerType::get(context, 32);
881   auto f16x2Ty = VectorType::get(2, f16Ty);
882   auto f32Ty = Float32Type::get(context);
883   auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
884       context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
887       LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
890   auto f16x2x2StructTy =
891       LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
893       LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
895       LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
897   std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
898                                   getShapeAttr().getK()};
// Candidate tables filled below; emptiness at the end means "unsupported".
904   AllowedShapes allowedShapes;
905   AllowedTypes expectedA;
906   AllowedTypes expectedB;
907   AllowedTypes expectedC;
// m16 family: dispatch on the A multiplicand's PTX type.
912   if (mmaShape[0] == 16) {
914     Type multiplicandFragType;
915     switch (*getMultiplicandAPtxType()) {
918       multiplicandFragType = i32Ty;
919       expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
920           context, {f32Ty, f32Ty, f32Ty, f32Ty}));
924       multiplicandFragType = i32Ty;
925       expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
926           context, {f32Ty, f32Ty, f32Ty, f32Ty}));
930       multiplicandFragType = f16x2Ty;
931       expectedResult.push_back(f16x2x2StructTy);
932       expectedResult.push_back(f32x4StructTy);
946       return emitError(
"invalid shape or multiplicand type: ")
947              << getMultiplicandAPtxType().value();
// Integer variants accumulate into s32x4; fp variants into f16x2/f32.
951       expectedResult.push_back(s32x4StructTy);
952       expectedC.emplace_back(4, i32Ty);
953       multiplicandFragType = i32Ty;
955       expectedC.emplace_back(2, f16x2Ty);
956       expectedC.emplace_back(4, f32Ty);
// Fragment sizes follow the PTX layout: (m/8)*(k/kFactor) regs for A, etc.
959     int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
960     int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
961     expectedA.emplace_back(unitA, multiplicandFragType);
962     expectedB.emplace_back(unitB, multiplicandFragType);
963     allowedShapes.push_back({16, 8, kFactor});
964     allowedShapes.push_back({16, 8, kFactor * 2});
966     if (resultPtxType() != accumPtxType())
// m8 family: legacy 8x8x4 f16/f64 shapes plus integer/b1 variants.
971   if (mmaShape[0] == 8) {
972     if (*getMultiplicandAPtxType() == MMATypes::f16) {
973       expectedA.emplace_back(2, f16x2Ty);
974       expectedB.emplace_back(2, f16x2Ty);
975       expectedResult.push_back(f16x2x4StructTy);
976       expectedResult.push_back(f32x8StructTy);
977       expectedC.emplace_back(4, f16x2Ty);
978       expectedC.emplace_back(8, f32Ty);
979       allowedShapes.push_back({8, 8, 4});
981     if (*getMultiplicandAPtxType() == MMATypes::f64) {
982       Type f64Ty = Float64Type::get(context);
983       expectedA.emplace_back(1, f64Ty);
984       expectedB.emplace_back(1, f64Ty);
985       expectedC.emplace_back(2, f64Ty);
986       expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
988       allowedShapes.push_back({8, 8, 4});
991       expectedA.push_back({i32Ty});
992       expectedB.push_back({i32Ty});
993       expectedC.push_back({i32Ty, i32Ty});
994       expectedResult.push_back(s32x2StructTy);
996         allowedShapes.push_back({8, 8, 32});
998         allowedShapes.push_back({8, 8, 16});
999       if (getMultiplicandAPtxType().value() == MMATypes::b1)
1000         allowedShapes.push_back({8, 8, 128});
// Diagnostics are accumulated into a string stream and emitted at once.
1004   std::string errorMessage;
1005   llvm::raw_string_ostream errorStream(errorMessage);
1008   if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1009       !llvm::is_contained(allowedShapes, mmaShape)) {
1010     errorStream <<
"unimplemented variant for MMA shape <";
1011     llvm::interleaveComma(mmaShape, errorStream);
// Check each operand segment's actual types against the candidate tables.
1017   std::array<StringRef, 3> operandNames{
"A",
"B",
"C"};
1018   for (
const auto &iter : llvm::enumerate(
1020     auto spec = this->getODSOperandIndexAndLength(iter.index());
1022                                   operand_type_begin() + spec.first +
1024     bool match = llvm::is_contained(iter.value(), operandTySeg);
1027       errorStream <<
"Could not match types for the "
1028                   << operandNames[iter.index()]
1029                   <<
" operands; expected one of ";
1030       for (
const auto &x : iter.value()) {
1031         errorStream << x.size() <<
"x" << x[0] <<
" ";
1033       errorStream <<
"but got ";
1034       llvm::interleaveComma(operandTySeg, errorStream);
1040   if (!llvm::any_of(expectedResult, [&](
Type expectedResultType) {
1041         return expectedResultType == getResult().getType();
1044         <<
"Could not match allowed types for the result; expected one of ";
1045     llvm::interleaveComma(expectedResult, errorStream);
1046     errorStream <<
" but got " << getResult().getType();
// b1 multiplicands require the b1Op attribute; integer types require
// intOverflowBehavior.
1051   if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
1052     return emitOpError(
"op requires " + getB1OpAttrName().strref() +
1060     if (!getIntOverflowBehavior())
1062                          getIntOverflowBehaviorAttrName().strref() +
// Only the legacy m8n8k4-f16 shape admits non-(row, col) layouts.
1070       (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
1071        getMultiplicandAPtxType() == MMATypes::f16);
1073   if (!isM8N8K4_F16) {
1075     if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
1076       return emitOpError(
"requires layoutA = #nvvm.mma_layout<row> and "
1077                          "layoutB = #nvvm.mma_layout<col> for shape <")
1078              << mmaShape[0] <<
", " << mmaShape[1] <<
", " << mmaShape[2]
1079              <<
"> with element types " << *getMultiplicandAPtxType() <<
" and "
1080              << *getMultiplicandBPtxType()
1081              <<
". Only m8n8k4 with f16 supports other layouts.";
// Sparse-MMA counterpart of MmaOp::accumPtxType: infers the C-segment PTX
// element type, expected to always succeed.
1088MMATypes MmaSpOp::accumPtxType() {
1089   std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
1090       getODSOperands(2).getTypes().front(),
true);
1091   assert(val.has_value() &&
"accumulator PTX type should always be inferrable");
// Sparse-MMA counterpart of MmaOp::resultPtxType: infers the result's PTX
// element type, expected to always succeed.
1095MMATypes MmaSpOp::resultPtxType() {
1096   std::optional<mlir::NVVM::MMATypes> val =
1097       MmaOp::inferOperandMMAType(getResult().
getType(),
true);
1098   assert(val.has_value() &&
"result PTX type should always be inferrable");
// Computes the LLVM intrinsic ID (plus args) for an nvvm.mma.sp op during
// translation, keyed on shape, overflow behavior, metadata ordering, kind,
// and the four PTX element types. NOTE(review): the function signature head
// and the argument-collection lines are missing from this extraction.
1104                                      llvm::IRBuilderBase &builder) {
1105   auto thisOp = cast<NVVM::MmaSpOp>(op);
1113   auto intId = MmaSpOp::getIntrinsicID(
1114       thisOp.getShape().getM(), thisOp.getShape().getN(),
1115       thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
1116       thisOp.getOrderedMetadata(), thisOp.getKind(),
1117       *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
1118       thisOp.accumPtxType(), thisOp.resultPtxType());
1120   return {intId, args};
// Custom printer for nvvm.mma.sp: like MmaOp's printer but with two extra
// singleton fragments for the sparse metadata and sparsity selector
// operands. NOTE(review): the print signature and several interior lines are
// missing from this extraction.
1125   struct MMAOperandFragment {
1126     StringRef operandName;
1127     StringRef ptxTypeAttr;
1128     SmallVector<Value, 4> regs;
1129     explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1130         : operandName(name), ptxTypeAttr(ptxTypeName) {}
1133   std::array<MMAOperandFragment, 5> frags{
1134       MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
1135       MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
1136       MMAOperandFragment(
"C",
""), MMAOperandFragment(
"sparseMetadata",
""),
1137       MMAOperandFragment(
"selector",
"")};
1139       mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
// Only the first three segments (A/B/C) are variadic register groups.
1142   for (
unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
1143     auto &frag = frags[fragIdx];
1144     auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
1145     for (
auto operandIdx = varOperandSpec.first;
1146          operandIdx < varOperandSpec.first + varOperandSpec.second;
1148       frag.regs.push_back(this->getOperand(operandIdx));
1149       if (operandIdx == varOperandSpec.first) {
1150         regTypes.push_back(this->getOperand(operandIdx).
getType());
1153     std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
1154         regTypes.back(), fragIdx >= 2);
1156       ignoreAttrNames.push_back(frag.ptxTypeAttr);
// Metadata and selector are single values, stored as singleton fragments.
1160   frags[3].regs.push_back(getSparseMetadata());
1161   frags[4].regs.push_back(getSparsitySelector());
1163   auto printMmaSpOperand = [&](
const MMAOperandFragment &frag) ->
void {
1164     p <<
" " << frag.operandName;
1170   for (
const auto &frag : frags)
1171     printMmaSpOperand(frag);
1176   for (
int i = 0; i < 3; ++i) {
1181   p <<
") -> " << getResult().getType();
// Builder for nvvm.mma.sp: mirrors MmaOp::build but also appends the sparse
// metadata and sparsity selector operands (each a segment of size 1) and has
// no layout attributes. NOTE(review): the builder signature head and closing
// lines are missing from this extraction.
1188                     std::optional<MMAIntOverflow> intOverflow,
1189                     std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1191   assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
1196   result.addOperands(operandA);
1197   result.addOperands(operandB);
1198   result.addOperands(operandC);
1199   result.addOperands(sparseMetadata);
1200   result.addOperands(sparsitySelector);
// Explicit ptx types win; otherwise infer from the first A/B operand type.
1202   if (multiplicandPtxTypes) {
1203     result.addAttribute(
"multiplicandAPtxType",
1204                         MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1205     result.addAttribute(
"multiplicandBPtxType",
1206                         MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1208     if (
auto res = MmaOp::inferOperandMMAType(operandA[0].
getType(),
false))
1209       result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1210     if (
auto res = MmaOp::inferOperandMMAType(operandB[0].
getType(),
false))
1211       result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1214   if (intOverflow.has_value())
1215     result.addAttribute(
"intOverflowBehavior",
1216                         MMAIntOverflowAttr::get(ctx, *intOverflow));
1218   result.addTypes(resultType);
// Segment sizes: |A|, |B|, |C|, then 1 each for metadata and selector.
1220       MmaSpOp::getOperandSegmentSizeAttr(),
1222                                 static_cast<int32_t>(operandB.size()),
1223                                 static_cast<int32_t>(operandC.size()), 1,
// Custom parser for nvvm.mma.sp: reads A/B/C plus the sparseMetadata and
// selector fragments, resolves register types from the trailing functional
// type, infers the multiplicand ptx-type attributes when absent, and records
// result type, attributes, and operand segment sizes. NOTE(review): the
// parse signature and many interior lines are missing from this extraction.
1228   struct MMAOperandFragment {
1229     std::optional<MMATypes> elemtype;
1230     SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1231     SmallVector<Type> regTypes;
1235   std::array<MMAOperandFragment, 6> frags;
1240   auto parseMmaSpOperand = [&](StringRef operandName,
1241                                MMAOperandFragment &frag) -> LogicalResult {
// Fragments are parsed in fixed order; any failure aborts the parse.
1252   if (parseMmaSpOperand(
"A", frags[0]).
failed())
1254   if (parseMmaSpOperand(
"B", frags[1]).
failed())
1256   if (parseMmaSpOperand(
"C", frags[2]).
failed())
1258   if (parseMmaSpOperand(
"sparseMetadata", frags[3]).
failed())
1260   if (parseMmaSpOperand(
"selector", frags[4]).
failed())
// The functional type must supply exactly one type per A/B/C segment.
1276   if (operandTypes.size() != 3)
1279                             "expected one type for each operand segment but got " +
1280                                 Twine(operandTypes.size()) +
" types");
1281   for (
const auto &iter : llvm::enumerate(operandTypes)) {
1282     auto &frag = frags[iter.index()];
1283     frag.regTypes.resize(frag.regs.size(), iter.value());
1288         MmaOp::inferOperandMMAType(frag.regTypes[0],
1296       MmaOp::inferOperandMMAType(resultType,
true);
1311   std::array<StringRef, 2> names{
"multiplicandAPtxType",
1312                                  "multiplicandBPtxType"};
1313   for (
unsigned idx = 0; idx < names.size(); idx++) {
1314     const auto &frag = frags[idx];
1315     std::optional<NamedAttribute> attr = namedAttributes.
getNamed(names[idx]);
1316     if (!frag.elemtype.has_value() && !attr.has_value()) {
1319                    "attribute " + names[idx] +
1320                        " is not provided explicitly and cannot be inferred");
1322     if (!attr.has_value())
1324           names[idx], MMATypesAttr::get(parser.
getContext(), *frag.elemtype));
1327   result.addTypes(resultType);
1328   if (!namedAttributes.
empty())
1329     result.addAttributes(namedAttributes);
1330   result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
1332                           static_cast<int32_t>(frags[0].regs.size()),
1333                           static_cast<int32_t>(frags[1].regs.size()),
1334                           static_cast<int32_t>(frags[2].regs.size()),
// Verifier for nvvm.mma.sp: parallels MmaOp::verify but with sparse-specific
// rules — the A fragment is half-sized (sparsity halves the stored A regs),
// additional f8 (e4m3/e5m2/e3m2/e2m3/e2m1) m16 variants are allowed, and the
// sparse metadata / sparsity selector operands must both be i32.
// NOTE(review): many interior lines of this large verifier are missing from
// this extraction.
1341LogicalResult MmaSpOp::verify() {
1343   auto f16Ty = Float16Type::get(context);
1344   auto i32Ty = IntegerType::get(context, 32);
1345   auto f16x2Ty = VectorType::get(2, f16Ty);
1346   auto f32Ty = Float32Type::get(context);
1347   auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
1348       context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
1350   auto s32x4StructTy =
1351       LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
1352   auto f32x8StructTy =
1354   auto f16x2x2StructTy =
1355       LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
1356   auto f32x4StructTy =
1357       LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
1358   auto s32x2StructTy =
1359       LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
1361   std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1362                                   getShapeAttr().getK()};
// Candidate tables filled below; emptiness at the end means "unsupported".
1368   AllowedShapes allowedShapes;
1369   AllowedTypes expectedA;
1370   AllowedTypes expectedB;
1371   AllowedTypes expectedC;
// m16 family: dispatch on the A multiplicand's PTX type.
1376   if (mmaShape[0] == 16) {
1378     Type multiplicandFragType;
1379     switch (*getMultiplicandAPtxType()) {
1380     case MMATypes::tf32:
1382       multiplicandFragType = i32Ty;
1383       expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1384           context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1386       allowedShapes.push_back({16, 8, 8});
1387       allowedShapes.push_back({16, 8, 16});
1389     case MMATypes::bf16:
1391       multiplicandFragType = i32Ty;
1392       expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1393           context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1395       allowedShapes.push_back({16, 8, 16});
1396       allowedShapes.push_back({16, 8, 32});
1400       multiplicandFragType = f16x2Ty;
1401       expectedResult.push_back(f16x2x2StructTy);
1402       expectedResult.push_back(f32x4StructTy);
1404       allowedShapes.push_back({16, 8, 16});
1405       allowedShapes.push_back({16, 8, 32});
1411       allowedShapes.push_back({16, 8, 64});
1412       allowedShapes.push_back({16, 8, 128});
1418       allowedShapes.push_back({16, 8, 32});
1419       allowedShapes.push_back({16, 8, 64});
// f8 families share one i32-fragment configuration at k=64.
1421     case MMATypes::e4m3:
1422     case MMATypes::e5m2:
1423     case MMATypes::e3m2:
1424     case MMATypes::e2m3:
1425     case MMATypes::e2m1:
1427       multiplicandFragType = i32Ty;
1428       expectedResult.push_back(f16x2x2StructTy);
1429       expectedResult.push_back(f32x4StructTy);
1431       allowedShapes.push_back({16, 8, 64});
1434       return emitError(
"invalid shape or multiplicand type: ")
1435              << getMultiplicandAPtxType().value();
1439       expectedResult.push_back(s32x4StructTy);
1440       expectedC.emplace_back(4, i32Ty);
1441       multiplicandFragType = i32Ty;
1442     }
else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
1443                *getMultiplicandAPtxType() <= MMATypes::e2m1) {
1445       expectedC.emplace_back(2, f16x2Ty);
1446       expectedC.emplace_back(4, f32Ty);
1448       expectedC.emplace_back(2, f16x2Ty);
1449       expectedC.emplace_back(4, f32Ty);
// Sparse A fragments hold half the registers of the dense layout.
1454     int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
1455     int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1456     expectedA.emplace_back(unitA, multiplicandFragType);
1457     expectedB.emplace_back(unitB, multiplicandFragType);
1459     if (resultPtxType() != accumPtxType())
// m8 family: legacy f16/f64 shapes plus integer variants.
1464   if (mmaShape[0] == 8) {
1465     if (*getMultiplicandAPtxType() == MMATypes::f16) {
1466       expectedA.emplace_back(2, f16x2Ty);
1467       expectedB.emplace_back(2, f16x2Ty);
1468       expectedResult.push_back(f16x2x4StructTy);
1469       expectedResult.push_back(f32x8StructTy);
1470       expectedC.emplace_back(4, f16x2Ty);
1471       expectedC.emplace_back(8, f32Ty);
1472       allowedShapes.push_back({8, 8, 4});
1474     if (*getMultiplicandAPtxType() == MMATypes::f64) {
1475       Type f64Ty = Float64Type::get(context);
1476       expectedA.emplace_back(1, f64Ty);
1477       expectedB.emplace_back(1, f64Ty);
1478       expectedC.emplace_back(2, f64Ty);
1479       expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1481       allowedShapes.push_back({8, 8, 4});
1484       expectedA.push_back({i32Ty});
1485       expectedB.push_back({i32Ty});
1486       expectedC.push_back({i32Ty, i32Ty});
1487       expectedResult.push_back(s32x2StructTy);
1489         allowedShapes.push_back({8, 8, 32});
1491         allowedShapes.push_back({8, 8, 16});
// Diagnostics are accumulated into a string stream and emitted at once.
1495   std::string errorMessage;
1496   llvm::raw_string_ostream errorStream(errorMessage);
1499   if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1500       !llvm::is_contained(allowedShapes, mmaShape)) {
1501     errorStream <<
"unimplemented variant for MMA shape <";
1502     llvm::interleaveComma(mmaShape, errorStream);
// Check each operand segment's actual types against the candidate tables.
1508   std::array<StringRef, 3> operandNames{
"A",
"B",
"C"};
1509   for (
const auto &iter : llvm::enumerate(
1511     auto spec = this->getODSOperandIndexAndLength(iter.index());
1513                                   operand_type_begin() + spec.first +
1515     bool match = llvm::is_contained(iter.value(), operandTySeg);
1518       errorStream <<
"Could not match types for the "
1519                   << operandNames[iter.index()]
1520                   <<
" operands; expected one of ";
1521       for (
const auto &x : iter.value()) {
1522         errorStream << x.size() <<
"x" << x[0] <<
" ";
1524       errorStream <<
"but got ";
1525       llvm::interleaveComma(operandTySeg, errorStream);
1531   if (!llvm::any_of(expectedResult, [&](
Type expectedResultType) {
1532         return expectedResultType == getResult().getType();
1535         <<
"Could not match allowed types for the result; expected one of ";
1536     llvm::interleaveComma(expectedResult, errorStream);
1537     errorStream <<
" but got " << getResult().getType();
1545     if (!getIntOverflowBehavior())
1547                          getIntOverflowBehaviorAttrName().strref() +
// Sparse-specific operand checks: both auxiliary operands must be i32.
1552   if (!getSparseMetadata().
getType().isInteger(32)) {
1553     return emitOpError() <<
"sparse metadata must be i32 type";
1557   if (!getSparsitySelector().
getType().isInteger(32)) {
1558     return emitOpError() <<
"sparsity selector must be i32 type";
1570struct MMAOperandFragment {
1571 StringRef operandName;
1572 StringRef ptxTypeAttr;
1573 SmallVector<Value, 4> regs;
1574 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1575 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1582 p <<
" " << name <<
"[";
1601template <
typename Op>
1606 for (
unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
1607 auto &frag = frags[fragIdx];
1608 auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
1609 for (
auto operandIdx = varOperandSpec.first;
1610 operandIdx < varOperandSpec.first + varOperandSpec.second;
1612 frag.regs.push_back(op.getOperand(operandIdx));
1613 if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
1614 regTypes.push_back(op.getOperand(operandIdx).getType());
1618 regTypes.push_back(frag.regs[0].getType());
1620 std::optional<MMATypes> inferredType =
1621 MmaOp::inferOperandMMAType(regTypes.back(),
1624 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1635 auto typeParser = [&]() {
1639 operandTypes.push_back(ty);
1645 if (operandTypes.size() != 3)
1647 "expected exactly 3 types");
1656 if (!attrs.
get(
"multiplicandAPtxType")) {
1657 if (
auto inferredType =
1658 MmaOp::inferOperandMMAType(operandTypes[0],
false)) {
1659 attrs.
set(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
1662 if (!attrs.
get(
"multiplicandBPtxType")) {
1663 if (
auto inferredType =
1664 MmaOp::inferOperandMMAType(operandTypes[1],
false)) {
1665 attrs.
set(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
1671template <
typename OpType>
1674 ScaleVecSize scaleVecSize,
1675 BlockScaleFormat blockScaleFormat,
1676 MMABlockScaleKind kind) {
1678 auto &properties =
result.getOrAddProperties<
typename OpType::Properties>();
1679 properties.setShape(
1681 properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
1682 properties.setBlockScaleFormat(
1683 BlockScaleFormatAttr::get(ctx, blockScaleFormat));
1684 properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
1691 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1692 if (multiplicandPtxTypes) {
1693 result.addAttribute(
"multiplicandAPtxType",
1694 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1695 result.addAttribute(
"multiplicandBPtxType",
1696 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1698 if (
auto res = MmaOp::inferOperandMMAType(operandA[0].
getType(),
false))
1699 result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1700 if (
auto res = MmaOp::inferOperandMMAType(operandB[0].
getType(),
false))
1701 result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1706template <
typename OpTy>
1708 return *MmaOp::inferOperandMMAType(
1709 cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
1719 std::array<MMAOperandFragment, 3> frags{
1720 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
1721 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
1722 MMAOperandFragment(
"C",
"")};
1724 mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
1729 for (
const auto &frag : frags)
1734 {getScaleAData(), getByteIdA(), getThreadIdA()});
1736 {getScaleBData(), getByteIdB(), getThreadIdB()});
1743 frags[1].regs[0].getType(),
1744 frags[2].regs[0].getType()},
1750ParseResult MmaBlockScaleOp::parse(
OpAsmParser &parser,
1752 struct LocalOperandFragment {
1753 std::optional<MMATypes> elemtype;
1754 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1758 std::array<LocalOperandFragment, 3> frags;
1787 for (
const auto &[idx, frag] : llvm::enumerate(frags)) {
1788 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
1791 .resolveOperands(frag.regs, operandTypes[idx], parser.
getNameLoc(),
1801 .resolveOperands(scaleAOperands, scaleTypes, parser.
getNameLoc(),
1811 result.addAttributes(namedAttributes);
1815 result.addTypes(resultTypes);
1816 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1818 static_cast<int32_t>(frags[0].regs.size()),
1819 static_cast<int32_t>(frags[1].regs.size()),
1820 static_cast<int32_t>(frags[2].regs.size()),
1831void MmaBlockScaleOp::build(
1836 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
1837 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
1838 MMABlockScaleKind kind) {
1839 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
1842 blockScaleFormat, kind);
1844 result.addOperands(operandA);
1845 result.addOperands(operandB);
1846 result.addOperands(operandC);
1848 {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
1851 multiplicandPtxTypes);
1853 result.addTypes(resultType);
1854 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1856 static_cast<int32_t>(operandA.size()),
1857 static_cast<int32_t>(operandB.size()),
1858 static_cast<int32_t>(operandC.size()),
1870 auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
1874 for (
Value operand : curOp.getOperandA())
1876 for (
Value operand : curOp.getOperandB())
1878 for (
Value operand : curOp.getOperandC())
1882 args.push_back(mt.
lookupValue(curOp.getScaleAData()));
1883 args.push_back(mt.
lookupValue(curOp.getByteIdA()));
1884 args.push_back(mt.
lookupValue(curOp.getThreadIdA()));
1885 args.push_back(mt.
lookupValue(curOp.getScaleBData()));
1886 args.push_back(mt.
lookupValue(curOp.getByteIdB()));
1887 args.push_back(mt.
lookupValue(curOp.getThreadIdB()));
1889 unsigned intId = MmaBlockScaleOp::getIntrinsicID(
1890 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
1891 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
1893 curOp.getBlockScaleFormat(), curOp.getKind());
1895 return {intId, args};
1898LogicalResult MmaBlockScaleOp::verify() {
1904 if (m == 16 && n == 8 && k == 64) {
1905 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
1906 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
1908 "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
1909 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
1910 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
1912 "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
1913 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
1915 "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
1916 }
else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
1917 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
1918 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
1919 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
1920 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
1921 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
1923 "attributes for mma.m16n8k64.mxf4nvf4");
1927 }
else if (m == 16 && n == 8 && k == 32) {
1928 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
1929 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
1930 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
1932 emitOpError(
"unsupported Kind, ScaleVecSize and BlockScaleFormat "
1933 "attributes for mma.m16n8k32");
1946 std::array<MMAOperandFragment, 3> frags{
1947 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
1948 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
1949 MMAOperandFragment(
"C",
"")};
1951 mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
1956 for (
const auto &frag : frags)
1965 {getScaleAData(), getByteIdA(), getThreadIdA()});
1967 {getScaleBData(), getByteIdB(), getThreadIdB()});
1974 frags[1].regs[0].getType(),
1975 frags[2].regs[0].getType()},
1981ParseResult MmaSpBlockScaleOp::parse(
OpAsmParser &parser,
1983 struct LocalOperandFragment {
1984 std::optional<MMATypes> elemtype;
1985 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1989 std::array<LocalOperandFragment, 3> frags;
2025 for (
const auto &[idx, frag] : llvm::enumerate(frags)) {
2026 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
2029 .resolveOperands(frag.regs, operandTypes[idx], parser.
getNameLoc(),
2038 .resolveOperands(metadataOperands, i32Type, parser.
getNameLoc(),
2051 .resolveOperands(scaleAOperands, scaleTypes, parser.
getNameLoc(),
2061 result.addAttributes(namedAttributes);
2066 if (!
result.attributes.get(
"orderedMetadata"))
2069 result.addTypes(resultTypes);
2070 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2072 static_cast<int32_t>(frags[0].regs.size()),
2073 static_cast<int32_t>(frags[1].regs.size()),
2074 static_cast<int32_t>(frags[2].regs.size()),
2087void MmaSpBlockScaleOp::build(
2093 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
2094 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
2095 MMABlockScaleKind kind) {
2096 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
2099 builder,
result,
shape, scaleVecSize, blockScaleFormat, kind);
2102 result.addOperands(operandA);
2103 result.addOperands(operandB);
2104 result.addOperands(operandC);
2105 result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
2106 threadIdA, scaleBData, byteIdB, threadIdB});
2109 multiplicandPtxTypes);
2111 result.addTypes(resultType);
2112 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2114 static_cast<int32_t>(operandA.size()),
2115 static_cast<int32_t>(operandB.size()),
2116 static_cast<int32_t>(operandC.size()),
2130 auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
2134 for (
Value operand : curOp.getOperandA())
2136 for (
Value operand : curOp.getOperandB())
2138 for (
Value operand : curOp.getOperandC())
2142 args.push_back(mt.
lookupValue(curOp.getSparseMetadata()));
2143 args.push_back(mt.
lookupValue(curOp.getSparsitySelector()));
2146 args.push_back(mt.
lookupValue(curOp.getScaleAData()));
2147 args.push_back(mt.
lookupValue(curOp.getByteIdA()));
2148 args.push_back(mt.
lookupValue(curOp.getThreadIdA()));
2149 args.push_back(mt.
lookupValue(curOp.getScaleBData()));
2150 args.push_back(mt.
lookupValue(curOp.getByteIdB()));
2151 args.push_back(mt.
lookupValue(curOp.getThreadIdB()));
2153 unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
2154 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
2155 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
2157 curOp.getBlockScaleFormat(), curOp.getKind());
2159 return {intId, args};
2162LogicalResult MmaSpBlockScaleOp::verify() {
2164 if (!getOrderedMetadata()) {
2165 return emitOpError(
"'orderedMetadata' attribute is mandatory");
2173 if (m == 16 && n == 8 && k == 128) {
2174 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2175 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2177 "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
2178 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2179 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2181 "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
2182 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2184 "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
2185 }
else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2186 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2187 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2188 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2189 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
2190 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
2192 "attributes for mma.m16n8k128.mxf4nvf4");
2196 }
else if (m == 16 && n == 8 && k == 64) {
2197 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2198 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2199 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2201 emitOpError(
"unsupported Kind, ScaleVecSize and BlockScaleFormat "
2202 "attributes for mma.m16n8k64");
2209LogicalResult ShflOp::verify() {
2210 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(
getType());
2212 auto verifyTypeError = [&](Twine desc,
Type expectedType,
2213 Type actualType) -> LogicalResult {
2214 return emitOpError(
"expected " + desc +
" to be of type ")
2215 << expectedType <<
" but got " << actualType <<
" instead";
2218 if (returnStructType) {
2219 if (!getReturnValueAndIsValid())
2220 return emitOpError(
"\"return_value_and_is_valid\" attribute must be "
2221 "specified when the return type is a struct type");
2223 if (returnStructType.getBody().size() != 2)
2224 return emitOpError(
"expected return type to be a two-element struct");
2227 auto resultType = returnStruct[0];
2228 if (resultType != getVal().
getType())
2229 return verifyTypeError(
"first element in the returned struct",
2230 getVal().
getType(), resultType);
2232 auto predicateType = returnStruct[1];
2233 if (!predicateType.isInteger(1))
2234 return verifyTypeError(
"second element in the returned struct",
2238 if (getReturnValueAndIsValid())
2239 return emitOpError(
"expected return type to be a two-element struct");
2242 return verifyTypeError(
"return type", getVal().
getType(),
getType());
2248 NVVM::MMAFrag frag,
int nRow,
2251 unsigned numberElements = 0;
2254 Type f16x2 = VectorType::get(2, builder.getF16Type());
2255 if (type == NVVM::MMATypes::f16) {
2256 elementType = f16x2;
2257 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2261 }
else if (type == NVVM::MMATypes::f32) {
2262 elementType = builder.getF32Type();
2264 }
else if (type == NVVM::MMATypes::f64) {
2265 elementType = builder.getF64Type();
2266 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2270 }
else if (type == NVVM::MMATypes::tf32) {
2271 elementType = builder.getI32Type();
2273 }
else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
2274 elementType = builder.getI32Type();
2275 int parallelSize = 0;
2276 if (frag == NVVM::MMAFrag::a)
2277 parallelSize = nRow;
2278 if (frag == NVVM::MMAFrag::b)
2279 parallelSize = nCol;
2282 if (parallelSize == 16)
2285 else if (parallelSize == 8)
2287 else if (parallelSize == 32)
2289 }
else if (type == NVVM::MMATypes::s32) {
2290 elementType = builder.getI32Type();
2293 assert(numberElements != 0 && elementType !=
nullptr);
2294 return std::make_pair(elementType, numberElements);
2297static std::pair<mlir::Type, unsigned>
2301 if (frag == NVVM::MMAFrag::a) {
2304 }
else if (frag == NVVM::MMAFrag::b) {
2311 assert(nRow && nCol);
2315LogicalResult NVVM::WMMALoadOp::verify() {
2316 unsigned addressSpace =
2317 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
2318 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2319 addressSpace != NVVMMemorySpace::Shared)
2320 return emitOpError(
"expected source pointer in memory "
2323 if (NVVM::WMMALoadOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
2324 getEltype(), getFrag()) == 0)
2325 return emitOpError() <<
"invalid attribute combination";
2330 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
2332 return emitOpError(
"expected destination type to be f64");
2336 Type dstType = LLVM::LLVMStructType::getLiteral(
2339 return emitOpError(
"expected destination type is a structure of ")
2340 << typeInfo.second <<
" elements of type " << typeInfo.first;
2344LogicalResult NVVM::WMMAStoreOp::verify() {
2345 unsigned addressSpace =
2346 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
2347 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2348 addressSpace != NVVMMemorySpace::Shared)
2349 return emitOpError(
"expected operands to be a source pointer in memory "
2352 if (NVVM::WMMAStoreOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
2354 return emitOpError() <<
"invalid attribute combination";
2357 if (getArgs().size() != typeInfo.second)
2358 return emitOpError() <<
"expected " << typeInfo.second <<
" data operands";
2359 if (llvm::any_of(getArgs(), [&typeInfo](
Value operands) {
2360 return operands.
getType() != typeInfo.first;
2362 return emitOpError() <<
"expected data operands of type " << typeInfo.first;
2366LogicalResult NVVM::WMMAMmaOp::verify() {
2367 if (NVVM::WMMAMmaOp::getIntrinsicID(
getM(),
getN(), getK(), getLayoutA(),
2368 getLayoutB(), getEltypeA(),
2370 return emitOpError() <<
"invalid attribute combination";
2378 arguments.append(typeInfoA.second, typeInfoA.first);
2379 arguments.append(typeInfoB.second, typeInfoB.first);
2380 arguments.append(typeInfoC.second, typeInfoC.first);
2381 unsigned numArgs = arguments.size();
2382 if (getArgs().size() != numArgs)
2383 return emitOpError() <<
"expected " << numArgs <<
" arguments";
2384 for (
unsigned i = 0; i < numArgs; i++) {
2385 if (getArgs()[i].
getType() != arguments[i])
2386 return emitOpError() <<
"expected argument " << i <<
" to be of type "
2389 Type dstType = LLVM::LLVMStructType::getLiteral(
2392 return emitOpError(
"expected destination type is a structure of ")
2393 << typeInfoC.second <<
" elements of type " << typeInfoC.first;
2397LogicalResult NVVM::LdMatrixOp::verify() {
2399 if (m == 8 && n == 8) {
2400 if (num != 1 && num != 2 && num != 4) {
2401 return emitOpError(
"expected num attribute to be 1, 2 or 4 for 8x8 "
2404 if (getEltType() != LdStMatrixEltType::B16) {
2405 return emitOpError(
"expected element type to be b16 for 8x8 matrix");
2407 }
else if (m == 8 && n == 16) {
2408 if (num != 1 && num != 2 && num != 4) {
2409 return emitOpError(
"expected num attribute to be 1, 2 or 4 for 8x16 "
2412 if (getLayout() != MMALayout::row) {
2413 return emitOpError(
"expected layout to be row for 8x16 matrix");
2415 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2416 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2417 return emitOpError(
"expected element type to be b8x16.b4x16_p64 or "
2418 "b8x16.b6x16_p32 for 8x16 matrix");
2420 }
else if (m == 16 && n == 16) {
2421 if (num != 1 && num != 2) {
2422 return emitOpError(
"expected num attribute to be 1 or 2 for 16x16 "
2425 if (getLayout() != MMALayout::col) {
2426 return emitOpError(
"expected layout to be col for 16x16 matrix");
2428 if (getEltType() != LdStMatrixEltType::B8 &&
2429 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2430 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2431 return emitOpError(
"expected element type to be b8, b8x16.b4x16_p64 or "
2432 "b8x16.b6x16_p32 for 16x16 matrix");
2435 return emitOpError(
"expected shape to be 8x8, 8x16 or 16x16");
2439 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
2440 if (numElements == 1 &&
getType() != i32)
2441 return emitOpError(
"expected destination type is i32");
2442 if (numElements == 2 || numElements == 4) {
2443 Type dstType = LLVM::LLVMStructType::getLiteral(
2446 return emitOpError(
"expected destination type is a structure of ")
2447 << numElements <<
" elements of type i32";
2453LogicalResult NVVM::StMatrixOp::verify() {
2454 int numMatrix = getSources().size();
2455 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
2456 return emitOpError(
"expected num attribute to be 1, 2 or 4");
2459 if (m == 8 && n == 8) {
2460 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
2461 return emitOpError(
"expected element type to be B16 for 8x8 matrix");
2463 }
else if (m == 16 && n == 8) {
2464 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
2465 return emitOpError(
"expected element type to be B8 for 16x8 matrix");
2467 if (getLayout() != NVVM::MMALayout::col) {
2468 return emitOpError(
"expected layout to be col for 16x8 matrix");
2471 return emitOpError(
"expected shape to be 8x8 or 16x8");
2478 if (typeA == NVVM::WGMMATypes::tf32)
2480 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
2482 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
2484 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
2486 if (typeA == NVVM::WGMMATypes::b1)
2492 NVVM::WGMMATypes typeA,
2493 NVVM::WGMMATypes typeB) {
2495 case NVVM::WGMMATypes::f16:
2496 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2497 typeB == NVVM::WGMMATypes::f16)
2500 case NVVM::WGMMATypes::tf32:
2501 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
2504 case NVVM::WGMMATypes::u8:
2505 case NVVM::WGMMATypes::s8:
2506 if (typeD == NVVM::WGMMATypes::s32 &&
2507 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
2510 case NVVM::WGMMATypes::b1:
2511 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
2514 case NVVM::WGMMATypes::bf16:
2515 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2516 typeB == NVVM::WGMMATypes::bf16)
2519 case NVVM::WGMMATypes::e4m3:
2520 case NVVM::WGMMATypes::e5m2:
2521 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2522 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
2525 case WGMMATypes::f32:
2526 case WGMMATypes::s32:
2527 llvm_unreachable(
"unsupported input types");
2535 72, 80, 88, 96, 104, 112, 120, 128,
2536 136, 144, 152, 160, 168, 176, 184, 192,
2537 200, 208, 216, 224, 232, 240, 248, 256};
2539 80, 96, 112, 128, 144, 160,
2540 176, 192, 208, 224, 240, 256};
2542 case WGMMATypes::f16:
2543 case WGMMATypes::tf32:
2544 case WGMMATypes::bf16:
2545 case WGMMATypes::e4m3:
2546 case WGMMATypes::e5m2:
2547 if (llvm::is_contained(allowedN, sizeN))
2550 case WGMMATypes::u8:
2551 case WGMMATypes::s8:
2552 case WGMMATypes::b1:
2553 if (llvm::is_contained(allowedNshort, sizeN))
2556 case WGMMATypes::f32:
2557 case WGMMATypes::s32:
2558 llvm_unreachable(
"unsupported input types");
2564LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
2565 Value outValue = getResults();
2566 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.
getType());
2568 return emitOpError() <<
"expected results to be struct";
2569 int outputSize = stype.getBody().size();
2570 WGMMATypes typeD = getTypeD();
2571 WGMMATypes typeA = getTypeA();
2572 WGMMATypes typeB = getTypeB();
2574 for (
Type t : stype.getBody()) {
2575 if (t != stype.getBody().front())
2577 <<
"all elements in struct must be same type but there is " << t;
2580 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
2581 typeD != WGMMATypes::s32) {
2582 return emitOpError() <<
"does not support the given output type " << typeD;
2584 if (typeD == WGMMATypes::s32 &&
2585 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
2586 return emitOpError() <<
"has s32 output, scaleA and scaleB cannot be neg";
2590 return emitOpError() << typeD <<
" += " << typeA <<
" * " << typeB
2591 <<
", it is not supported.";
2601 return emitOpError() <<
"shape 'k' must be " << allowedK.value()
2602 <<
" for input type " << typeA;
2606 return emitOpError() <<
"has input type " << typeA <<
" n is set to "
2607 <<
getShape().getN() <<
", it is not supported.";
2614 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
2615 (getLayoutA() == mlir::NVVM::MMALayout::col ||
2616 getLayoutB() == mlir::NVVM::MMALayout::row)) {
2618 <<
"given layouts layout_a = " << getLayoutA()
2619 <<
" and layout_b = " << getLayoutB() <<
" for input types " << typeA
2621 <<
" requires transpose. However, this is only supported for: "
2622 << MMATypes::f16 <<
" and " << MMATypes::bf16;
2626 int expectedOutput = 0;
2627 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
2628 expectedOutput =
getShape().getN() / 2;
2629 if (typeD == WGMMATypes::f16)
2630 expectedOutput =
getShape().getN() / 4;
2631 if (outputSize != expectedOutput) {
2632 return emitOpError() <<
"results " << expectedOutput
2633 <<
", however output struct has " << outputSize
2637 if (typeD != WGMMATypes::s32 &&
2638 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2639 NVVM::MMAIntOverflow::satfinite) {
2641 <<
" `satfinite` can be only used with s32 accumulator, however "
2642 "the current accumulator is "
2649std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
2652 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2654 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
2656 int expectedOutputRegisters = 0;
2657 if (getTypeD() == WGMMATypes::f16)
2658 expectedOutputRegisters =
getShape().getN() / 4;
2660 expectedOutputRegisters =
getShape().getN() / 2;
2663 llvm::raw_string_ostream ss(ptx);
2668 << ((expectedOutputRegisters * 2) + 2)
2670 "wgmma.mma_async.sync.aligned.m"
2671 << m <<
"n" << n <<
"k" << k <<
"." << outputTypeName <<
"." << getTypeA()
2672 <<
"." << getTypeB();
2673 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2674 NVVM::MMAIntOverflow::satfinite)
2678 for (; regCnt < expectedOutputRegisters; ++regCnt) {
2679 ss <<
"$" << regCnt;
2680 if (regCnt != expectedOutputRegisters - 1)
2686 regCnt = (regCnt * 2);
2687 ss <<
" $" << (regCnt) <<
","
2688 <<
" $" << (regCnt + 1) <<
","
2690 if (getTypeD() != WGMMATypes::s32) {
2691 ss <<
", $" << (regCnt + 3) <<
", $" << (regCnt + 4);
2695 ss <<
", $" << (regCnt + 5) <<
", $" << (regCnt + 6);
2702bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2706 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2713 asmValues.push_back({makeConstantI32(rewriter,
static_cast<int>(getScaleD())),
2715 if (getTypeD() != WGMMATypes::s32) {
2716 asmValues.push_back(
2717 {makeConstantI32(rewriter,
2718 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2720 asmValues.push_back(
2721 {makeConstantI32(rewriter,
2722 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2726 asmValues.push_back(
2727 {makeConstantI32(rewriter,
static_cast<int>(getLayoutA())),
2729 asmValues.push_back(
2730 {makeConstantI32(rewriter, 1 -
static_cast<int>(getLayoutB())),
2736LogicalResult NVVM::FenceProxyOp::verify() {
2737 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2738 return emitOpError() <<
"async_shared fence requires space attribute";
2740 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2741 return emitOpError() <<
"only async_shared fence can have space attribute";
2746LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2747 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2748 return emitOpError(
"uni-directional proxies only support generic for "
2749 "from_proxy attribute");
2751 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2752 return emitOpError(
"uni-directional proxies only support tensormap "
2753 "for to_proxy attribute");
2757LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2758 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2759 return emitOpError(
"uni-directional proxies only support generic for "
2760 "from_proxy attribute");
2762 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2763 return emitOpError(
"uni-directional proxies only support tensormap "
2764 "for to_proxy attribute");
2768LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2769 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2770 return emitOpError(
"only generic is support for from_proxy attribute");
2772 if (getToProxy() != NVVM::ProxyKind::async)
2773 return emitOpError(
"only async is supported for to_proxy attribute");
2777LogicalResult NVVM::SetMaxRegisterOp::verify() {
2778 if (getRegCount() % 8)
2779 return emitOpError(
"new register size must be multiple of 8");
2780 if (getRegCount() < 24 || getRegCount() > 256)
2781 return emitOpError(
"new register size must be in between 24 to 256");
2785LogicalResult NVVM::BarrierOp::verify() {
2786 if (getNumberOfThreads() && !getBarrierId())
2788 "barrier id is missing, it should be set between 0 to 15");
2790 if (getBarrierId() && (
getReductionOp() || getReductionPredicate()))
2791 return emitOpError(
"reduction are only available when id is 0");
2795 return emitOpError(
"reduction predicate and reduction operation must be "
2796 "specified together");
2801LogicalResult NVVM::Tcgen05CpOp::verify() {
2802 auto mc = getMulticast();
2804 using SH = Tcgen05CpShape;
2805 using MC = Tcgen05CpMulticast;
2807 case SH::SHAPE_128x256b:
2808 case SH::SHAPE_128x128b:
2809 case SH::SHAPE_4x256b:
2811 return emitError(
"Invalid multicast type for tcgen05.cp Op");
2813 case SH::SHAPE_64x128b:
2814 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
2815 return emitError(
"Shape 64x128b requires multicast warpx2_01_23 or "
2816 "warpx2_02_13 for tcgen05.cp Op");
2818 case SH::SHAPE_32x128b:
2819 if (mc != MC::WARPX4)
2821 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
2827LogicalResult NVVM::MatchSyncOp::verify() {
2828 if (getKind() == NVVM::MatchSyncKind::all) {
2829 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(
getType());
2830 if (!type || type.getBody().size() != 2 ||
2831 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2832 return emitOpError(
"match.sync 'all' returns a two element struct with "
2833 "first element as i32 and second element as i1");
2836 if (!
getType().isInteger(32)) {
2837 return emitOpError(
"match.sync 'any' returns an i32");
2843LogicalResult NVVM::VoteSyncOp::verify() {
2844 if (getKind() == NVVM::VoteSyncKind::ballot) {
2845 if (!
getType().isInteger(32)) {
2846 return emitOpError(
"vote.sync 'ballot' returns an i32");
2849 if (!
getType().isInteger(1)) {
2850 return emitOpError(
"vote.sync 'any', 'all' and 'uni' returns an i1");
2856LogicalResult NVVM::PrefetchOp::verify() {
2857 using MemSpace = NVVM::NVVMMemorySpace;
2858 using CacheLevel = NVVM::PrefetchCacheLevel;
2860 unsigned addressSpace =
2861 llvm::cast<LLVM::LLVMPointerType>(getAddr().
getType()).getAddressSpace();
2862 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
2863 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
2865 if (getTensormap() && cacheLevel)
2866 return emitOpError(
"cannot specify both tensormap and cache level");
2868 if (getTensormap()) {
2869 if (addressSpace != MemSpace::Generic &&
2870 addressSpace != MemSpace::Constant) {
2872 "prefetch tensormap requires a generic or constant pointer");
2875 if (evictPriority) {
2877 "prefetch tensormap does not support eviction priority");
2880 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
2882 "in_param_space can only be specified for a generic pointer");
2885 }
else if (cacheLevel) {
2886 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
2887 addressSpace != MemSpace::Local) {
2888 return emitOpError(
"prefetch to cache level requires a generic, global, "
2889 "or local pointer");
2893 if (*cacheLevel != CacheLevel::L1) {
2895 "unsupported cache level, the only supported uniform "
2896 "cache level is L1");
2899 if (addressSpace != MemSpace::Generic) {
2901 "prefetch to uniform cache requires a generic pointer");
2905 if (evictPriority) {
2906 if (*cacheLevel != CacheLevel::L2)
2908 "cache eviction priority supported only for cache level L2");
2910 if (addressSpace != MemSpace::Global)
2911 return emitOpError(
"cache eviction priority requires a global pointer");
2913 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
2914 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
2916 "unsupported cache eviction priority, only evict_last and "
2917 "evict_normal are supported");
2921 return emitOpError(
"predicate supported only on prefetch tensormap");
2925 "requires specification of either cache level or tensormap");
2931LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
2932 switch (getQueryType()) {
2933 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
2935 return emitOpError(
"is_canceled query type returns an i1");
2937 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
2938 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
2939 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
2940 if (!
getType().isInteger(32)) {
2941 return emitOpError(
"get_first_cta_id_x, get_first_cta_id_y, "
2942 "get_first_cta_id_z query types return an i32");
2949LogicalResult NVVM::ReduxOp::verify() {
2952 if (!reduxType.
isF32()) {
2954 return emitOpError(
"abs attribute is supported only for f32 type");
2956 return emitOpError(
"nan attribute is supported only for f32 type");
2959 NVVM::ReductionKind kind = getKind();
2961 case NVVM::ReductionKind::ADD:
2962 case NVVM::ReductionKind::AND:
2963 case NVVM::ReductionKind::OR:
2964 case NVVM::ReductionKind::XOR:
2965 case NVVM::ReductionKind::MAX:
2966 case NVVM::ReductionKind::MIN:
2967 case NVVM::ReductionKind::UMAX:
2968 case NVVM::ReductionKind::UMIN:
2971 << kind <<
"' reduction kind unsupported with " << reduxType
2972 <<
" type. Only supported type is 'i32'.";
2974 case NVVM::ReductionKind::FMIN:
2975 case NVVM::ReductionKind::FMAX:
2976 if (!reduxType.isF32())
2978 << kind <<
"' reduction kind unsupported with " << reduxType
2979 <<
" type. Only supported type is 'f32'.";
2986LogicalResult NVVM::TensormapReplaceOp::verify() {
2987 auto ord = getOrd();
2988 Value newVal = getNewValue();
2989 auto newValAttr = getNewValueAttr();
2990 auto fieldName = stringifyEnum(getField());
2992 if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
2993 NVVM::TensormapField::GLOBAL_DIM,
2994 NVVM::TensormapField::GLOBAL_STRIDE,
2995 NVVM::TensormapField::ELEMENT_STRIDE},
2997 return emitOpError(
"ordinal is not supported for ")
2998 << fieldName <<
" field";
3000 auto invalidNewVal = [&](llvm::Twine type) -> std::string {
3001 return llvm::Twine(
"new_value must be specified and must be an " + type +
3002 " for " + llvm::Twine(fieldName) +
" field")
3006 auto invalidNewValAttr = [&]() -> std::string {
3007 return (llvm::Twine(
3008 "new_value_attr must be specified and must be a valid ") +
3009 llvm::Twine(fieldName) +
" attribute for " + fieldName +
" field")
3013 switch (getField()) {
3014 case NVVM::TensormapField::GLOBAL_ADDRESS:
3018 case NVVM::TensormapField::RANK:
3022 case NVVM::TensormapField::GLOBAL_STRIDE:
3024 return emitOpError(
"ordinal is required for global_stride field");
3028 case NVVM::TensormapField::BOX_DIM:
3029 case NVVM::TensormapField::GLOBAL_DIM:
3030 case NVVM::TensormapField::ELEMENT_STRIDE:
3033 << stringifyEnum(getField()) <<
" field";
3037 case NVVM::TensormapField::ELEMTYPE:
3038 if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
3041 case NVVM::TensormapField::INTERLEAVE_LAYOUT:
3042 if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
3045 case NVVM::TensormapField::SWIZZLE_MODE:
3046 if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
3049 case NVVM::TensormapField::SWIZZLE_ATOMICITY:
3050 if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
3053 case NVVM::TensormapField::FILL_MODE:
3054 if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
3062template <
typename OpType>
3064 mlir::NVVM::FPRoundingMode rndMode = op.getRnd();
3065 mlir::NVVM::SaturationMode satMode = op.getSat();
3066 bool isFTZ = op.getFtz();
3069 mlir::Type opBaseType = isa<VectorType>(opType)
3070 ? cast<VectorType>(opType).getElementType()
3073 if (opBaseType.
isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3074 return op.emitOpError(
"FTZ and saturation are not supported for "
3075 "additions/subtractions involving f64 type");
3077 if (opBaseType.
isF16() && !(rndMode == NVVM::FPRoundingMode::RN ||
3078 rndMode == NVVM::FPRoundingMode::NONE))
3079 return op.emitOpError(
"only RN rounding mode is supported for f16 and "
3080 "vector<2xf16> additions/subtractions");
3082 if (opBaseType.
isBF16()) {
3083 if (rndMode != NVVM::FPRoundingMode::RN &&
3084 rndMode != NVVM::FPRoundingMode::NONE)
3085 return op.emitOpError(
"only RN rounding mode is supported for bf16 and "
3086 "vector<2xbf16> additions/subtractions");
3087 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3088 return op.emitOpError(
"FTZ and saturation are not supported for bf16 and "
3089 "vector<2xbf16> additions/subtractions");
3096 if (opBaseType.
isF16() && isFTZ && satMode == NVVM::SaturationMode::NONE)
3097 return op.emitOpError(
"FTZ with no saturation is not supported for f16 and "
3098 "vector<2xf16> additions/subtractions");
3107LogicalResult NVVM::FmaOp::verify() {
3108 auto opType = getRes().getType();
3109 mlir::NVVM::FPRoundingMode rndMode = getRnd();
3110 mlir::NVVM::SaturationMode satMode = getSat();
3111 bool isFTZ = getFtz();
3112 bool isRelu = getRelu();
3113 bool hasOOB = getOob();
3115 auto getBaseFType = [](
Type type) ->
Type {
3116 if (isa<VectorType>(type))
3117 return cast<VectorType>(type).getElementType();
3121 auto opBaseType = getBaseFType(opType);
3123 if (rndMode == NVVM::FPRoundingMode::NONE)
3124 return emitOpError(
"rounding mode must be specified");
3126 if (isRelu && satMode == NVVM::SaturationMode::SAT)
3127 return emitOpError(
"relu and saturation are not supported together");
3129 if (hasOOB && (satMode == NVVM::SaturationMode::SAT || isFTZ))
3130 return emitOpError(
"oob is not supported with saturation or FTZ");
3132 if (!(opBaseType.isF16() || opBaseType.isBF16()) && (isRelu || hasOOB))
3133 return emitOpError(
"relu and oob are only supported for f16 and bf16");
3135 if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3136 return emitOpError(
"FTZ and saturation are not supported for f64 type");
3138 if (opBaseType.isF16() && rndMode != NVVM::FPRoundingMode::RN)
3140 "only RN rounding mode is supported for f16 and vector<2xf16>");
3142 if (opBaseType.isBF16()) {
3143 if (rndMode != NVVM::FPRoundingMode::RN)
3145 "only RN rounding mode is supported for bf16 and vector<2xbf16>");
3146 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3148 "FTZ and saturation are not supported for bf16 and vector<2xbf16>");
3160 unsigned sizeInBits,
3162 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3164 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3165 if (mask != 0xffffffffu)
3166 field = builder.CreateAnd(field, builder.getInt32(mask));
3168 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3169 field = builder.CreateShl(field, start);
3171 return builder.CreateOr(
result, field);
3174void Tcgen05MmaSmemDescOp::createSmemDescriptor(
Operation &op,
3176 llvm::IRBuilderBase &builder) {
3177 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
3178 llvm::Value *smemDesc = builder.getInt64(0);
3183 builder, smemDesc, mt.
lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
3185 builder, smemDesc, mt.
lookupValue(thisOp.getStrideDimOffset()), 14, 32);
3191 builder, smemDesc, mt.
lookupValue(thisOp.getLeadingDimMode()), 1, 52);
3195 mt.
mapValue(thisOp.getRes()) = smemDesc;
3202std::string NVVM::MBarrierInitOp::getPtx() {
3204 return isShared ? std::string(
"mbarrier.init.shared.b64 [%0], %1;")
3205 : std::string(
"mbarrier.init.b64 [%0], %1;");
3208std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3211 ? std::string(
"mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3212 : std::string(
"mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
3215std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
3217 llvm::StringRef space = isShared ?
".shared" :
"";
3219 return llvm::formatv(
"{\n\t"
3220 ".reg .pred P1; \n\t"
3222 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
3223 "@P1 bra.uni DONE; \n\t"
3224 "bra.uni LAB_WAIT; \n\t"
3241 LLVM::FNegOp::create(rewriter, loc, op.getRhs().getType(), op.getRhs());
3244 op.getRnd(), op.getSat(), op.getFtz());
3260 auto thisOp = cast<NVVM::BarrierOp>(op);
3261 llvm::Value *barrierId = thisOp.getBarrierId()
3263 : builder.getInt32(0);
3264 llvm::Intrinsic::ID id;
3266 if (thisOp.getNumberOfThreads()) {
3267 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
3268 args.push_back(mt.
lookupValue(thisOp.getNumberOfThreads()));
3269 }
else if (thisOp.getReductionOp()) {
3270 switch (*thisOp.getReductionOp()) {
3271 case NVVM::BarrierReduction::AND:
3272 id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
3274 case NVVM::BarrierReduction::OR:
3275 id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
3277 case NVVM::BarrierReduction::POPC:
3278 id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
3281 args.push_back(builder.CreateICmpNE(
3282 mt.
lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
3284 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
3287 return {id, std::move(args)};
3292 llvm::IRBuilderBase &builder) {
3293 auto thisOp = cast<NVVM::PMEventOp>(op);
3297 llvm::Value *maskVal;
3298 if (
auto eventAttr = thisOp.getEventIdAttr()) {
3299 uint16_t mask =
static_cast<uint16_t
>(1u << eventAttr.getInt());
3300 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3303 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3306 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3311 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3313 llvm::Intrinsic::ID
id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3314 : llvm::Intrinsic::nvvm_mbarrier_init;
3319 args.push_back(mt.
lookupValue(thisOp.getCount()));
3321 return {id, std::move(args)};
3326 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3328 llvm::Intrinsic::ID
id = isShared
3329 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3330 : llvm::Intrinsic::nvvm_mbarrier_inval;
3337 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3340 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3343 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3345 static constexpr llvm::Intrinsic::ID IDs[] = {
3346 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3347 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3348 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3349 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3354 args.push_back(mt.
lookupValue(thisOp.getTxcount()));
3356 return {IDs[
index], std::move(args)};
3361 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3364 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3367 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3369 static constexpr llvm::Intrinsic::ID IDs[] = {
3370 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3371 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3372 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3373 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3378 args.push_back(mt.
lookupValue(thisOp.getTxcount()));
3380 return {IDs[
index], std::move(args)};
3385 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
3388 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3391 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3393 static constexpr llvm::Intrinsic::ID IDs[] = {
3394 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
3395 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
3396 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
3397 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
3398 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3399 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
3400 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
3401 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
3403 nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
3404 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3408 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3415 bool hasCount =
static_cast<bool>(thisOp.getCount());
3417 (
id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
3418 return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};
3422 llvm::Value *count =
3424 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3425 return {id, {mbar, count}};
3430 auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
3433 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3436 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3438 static constexpr llvm::Intrinsic::ID IDs[] = {
3439 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
3440 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
3441 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
3442 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
3443 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3444 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
3446 nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
3448 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
3450 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
3451 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3455 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3461 bool hasCount =
static_cast<bool>(thisOp.getCount());
3462 llvm::Value *count =
3464 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3466 return {id, {mbar, count}};
3469bool MBarrierArriveExpectTxOp::getAsmValues(
3476 for (
auto val : getOperands())
3484 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3487 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3490 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3493 static constexpr llvm::Intrinsic::ID IDs[] = {
3494 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3495 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3496 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3497 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
3498 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3499 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3500 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3501 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3502 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
3504 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3507 llvm::Value *txcount = mt.
lookupValue(thisOp.getTxcount());
3508 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3513 return {id, {mbar, txcount}};
3518 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3521 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3524 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3527 static constexpr llvm::Intrinsic::ID IDs[] = {
3528 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3529 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3530 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3531 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3532 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3533 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3534 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3535 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3536 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3538 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3541 llvm::Value *txcount = mt.
lookupValue(thisOp.getTxcount());
3542 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3547 return {id, {mbar, txcount}};
3552 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3554 llvm::Intrinsic::ID
id =
3555 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3556 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3560 args.push_back(mt.
lookupValue(thisOp.getCount()));
3562 return {id, std::move(args)};
3567 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3569 llvm::Intrinsic::ID
id =
3570 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3571 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3575 args.push_back(mt.
lookupValue(thisOp.getCount()));
3577 return {id, std::move(args)};
3582 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3583 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3584 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3587 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3590 static constexpr llvm::Intrinsic::ID IDs[] = {
3591 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3592 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3593 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3594 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3595 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3596 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3597 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3598 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3599 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3601 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3604 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3605 llvm::Value *input = mt.
lookupValue(thisOp.getStateOrPhase());
3610 return {id, {mbar, input}};
3615 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3616 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3617 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3618 bool hasTicks =
static_cast<bool>(thisOp.getTicks());
3622 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3623 (isPhaseParity ? 1 : 0);
3626 static constexpr llvm::Intrinsic::ID IDs[] = {
3627 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3628 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3629 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3630 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3631 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3632 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3633 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3634 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
3635 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3636 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
3637 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
3638 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
3639 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
3640 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
3641 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
3642 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
3643 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
3645 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3648 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3655 args.push_back(mbar);
3656 args.push_back(mt.
lookupValue(thisOp.getStateOrPhase()));
3658 args.push_back(mt.
lookupValue(thisOp.getTicks()));
3660 return {id, std::move(args)};
3665 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
3668 llvm::Intrinsic::ID id;
3669 if (thisOp.getNoinc()) {
3670 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
3671 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
3673 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
3674 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
3680#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
3681 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
3683#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
3684 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
3689 llvm::Intrinsic::ID id;
3691 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
3692 bool hasCpSize =
static_cast<bool>(cpAsyncOp.getCpSize());
3693 switch (cpAsyncOp.getSize()) {
3701 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
3706 llvm_unreachable(
"Invalid copy size in CpAsyncOp.");
3710 args.push_back(mt.
lookupValue(cpAsyncOp.getDst()));
3711 args.push_back(mt.
lookupValue(cpAsyncOp.getSrc()));
3713 args.push_back(mt.
lookupValue(cpAsyncOp.getCpSize()));
3720 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
3722 llvm::Intrinsic::ID
id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
3725 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3729 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3730 llvm::Value *i64Unused =
3731 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3732 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3733 args.push_back(builder.getInt1(hasCacheHint));
3735 return {id, std::move(args)};
3740 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
3744 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
3746 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3750 mlir::Value multicastMask = thisOp.getMulticastMask();
3751 const bool hasMulticastMask =
static_cast<bool>(multicastMask);
3754 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
3755 args.push_back(hasMulticastMask ? mt.
lookupValue(multicastMask)
3761 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3762 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
3763 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3767 args.push_back(builder.getInt1(hasMulticastMask));
3768 args.push_back(builder.getInt1(hasCacheHint));
3770 llvm::Intrinsic::ID
id =
3772 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
3773 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
3775 return {id, std::move(args)};
3780 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
3782 llvm::Intrinsic::ID
id =
3783 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
3786 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
3787 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3791 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3792 llvm::Value *i64Unused =
3793 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3794 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3795 args.push_back(builder.getInt1(hasCacheHint));
3798 if (
mlir::Value byteMask = thisOp.getByteMask()) {
3800 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
3803 return {id, std::move(args)};
3806bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
3813 for (
auto val : getOperands())
3820CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3822 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
3823 const bool isCTAOnly = thisOp.getIsCTAOnly();
3827 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
3829 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
3839 const bool hasMC =
static_cast<bool>(mcMask);
3840 llvm::Value *i16Zero =
3841 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.
getLLVMContext()), 0);
3845 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3846 llvm::Value *i64Zero =
3847 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3853 thisOp.getGroup() ? (
static_cast<int32_t
>(*thisOp.getGroup()) + 1) : 0;
3855 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.
getLLVMContext()), val);
3859 args.push_back(hasMC ? mt.
lookupValue(mcMask) : i16Zero);
3860 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Zero);
3861 args.push_back(builder.getInt1(hasMC));
3862 args.push_back(builder.getInt1(hasCacheHint));
3866 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Zero);
3867 args.push_back(builder.getInt1(hasCacheHint));
3870 constexpr size_t numDims = 5;
3871 constexpr size_t numModes = 5;
3872 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
3873 using TableTy = std::array<rowTy, numModes>;
3874 static constexpr TableTy IDTable{
3875 {{
notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
3876 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
3877 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
3878 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
3879 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
3881 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
3882 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
3883 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
3885 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
3886 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
3887 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
3889 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
3890 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
3891 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
3893 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
3895 static constexpr TableTy IDTableCTA{
3897 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
3898 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
3899 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
3900 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
3901 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
3903 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
3904 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
3905 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
3907 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
3908 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
3909 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
3911 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
3912 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
3913 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
3915 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
3918 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
3919 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
3920 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
3921 size_t mode =
static_cast<size_t>(thisOp.getMode());
3922 size_t dim = thisOp.getCoordinates().size();
3923 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
3925 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
3927 return {id, std::move(args)};
3932 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
3936 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
3938 for (
auto v : thisOp.getCoordinates())
3940 for (
auto v : thisOp.getIm2colOffsets())
3944 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3945 llvm::Value *i64Unused =
3946 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3947 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3948 args.push_back(builder.getInt1(hasCacheHint));
3950 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3951 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3952 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
3953 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
3954 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
3955 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
3956 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
3958 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
3959 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
3960 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
3962 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
3963 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
3964 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
3966 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
3967 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
3968 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
3969 {NI, NI, NI, NI, NI,
3970 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
3972 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
3973 "TMALoadModes must match number of rows in IDTable");
3974 size_t mode =
static_cast<size_t>(thisOp.getMode());
3975 size_t dim = thisOp.getCoordinates().size();
3976 llvm::Intrinsic::ID
id = IDTable[mode][dim];
3977 if (
id == llvm::Intrinsic::not_intrinsic)
3978 llvm_unreachable(
"Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
3980 return {id, std::move(args)};
3984CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
3986 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
3990 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3991 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
3993 for (
auto v : thisOp.getCoordinates())
3997 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3998 llvm::Value *i64Unused =
3999 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4000 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4001 args.push_back(builder.getInt1(hasCacheHint));
4003 const unsigned NI = llvm::Intrinsic::not_intrinsic;
4004 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4005 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
4006 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
4007 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
4008 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
4009 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
4010 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
4011 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
4012 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
4013 {NI, NI, NI, NI, NI,
4014 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
4016 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
4017 "TMAStoreModes must match number of rows in IDTable");
4018 size_t mode =
static_cast<size_t>(thisOp.getMode());
4019 size_t dim = thisOp.getCoordinates().size();
4020 llvm::Intrinsic::ID
id = IDTable[mode][dim];
4021 if (
id == llvm::Intrinsic::not_intrinsic)
4023 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
4025 return {id, std::move(args)};
4030 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
4038 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4039 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
4041 for (
Value v : thisOp.getCoordinates())
4045 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4046 llvm::Value *i64ZeroValue =
4047 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
4048 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64ZeroValue);
4049 args.push_back(builder.getInt1(hasCacheHint));
4051 const llvm::Intrinsic::ID
notIntrinsic = llvm::Intrinsic::not_intrinsic;
4053 constexpr unsigned numRedKinds = 8;
4054 constexpr unsigned numLayouts = 2;
4055 constexpr unsigned maxDim = 5;
4056 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
4057 using layoutTable = std::array<row, numLayouts>;
4058 using fullTable = std::array<layoutTable, numRedKinds>;
4059 static constexpr fullTable IDTable{
4062 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
4063 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
4064 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
4065 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
4066 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
4068 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
4069 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
4070 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
4073 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
4074 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
4075 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
4076 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
4077 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
4079 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
4080 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
4081 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
4084 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
4085 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
4086 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
4087 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
4088 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
4090 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
4091 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
4092 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
4095 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
4096 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
4097 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
4098 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
4099 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
4101 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
4102 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
4103 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
4106 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
4107 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
4108 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
4109 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
4110 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
4112 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
4113 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
4114 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
4117 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
4118 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
4119 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
4120 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
4121 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
4123 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
4124 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
4125 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
4128 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
4129 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
4130 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
4131 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
4132 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
4134 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
4135 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
4136 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
4139 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
4140 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
4141 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
4142 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
4143 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
4145 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
4146 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
4148 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
4150 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
4151 "TMAReduxKinds must match number of rows in IDTable");
4153 size_t redKind =
static_cast<size_t>(thisOp.getRedKind());
4154 size_t mode =
static_cast<size_t>(thisOp.getMode());
4155 size_t dim = thisOp.getCoordinates().size();
4157 assert(redKind < IDTable.size() &&
4158 "Invalid redKind for CpAsyncBulkTensorReduceOp");
4159 assert(mode < IDTable[redKind].size() &&
4160 "Invalid mode for CpAsyncBulkTensorReduceOp");
4161 assert(dim < IDTable[redKind][mode].size() &&
4162 "Invalid dim for CpAsyncBulkTensorReduceOp");
4164 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
4167 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4169 return {intrinsicID, std::move(args)};
4174#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4175 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
4176 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
4178#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
4179 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4180 : CVT_F2TF32_ID_IMPL(rnd, relu, )
4183ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4184 NVVM::SaturationMode sat,
bool hasRelu) {
4185 using RndMode = NVVM::FPRoundingMode;
4186 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4195 llvm_unreachable(
"Invalid RoundingMode for CvtFloatToTF32Op");
4200ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4202 llvm::IRBuilderBase &builder) {
4207 bool hasRelu = op.getRelu();
4209 llvm::Intrinsic::ID intId =
4210 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4211 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4213 return {intId, std::move(args)};
4216#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
4217 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
4218 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
4220llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4223 .Case([&](mlir::Float6E2M3FNType) {
4226 .Case([&](mlir::Float6E3M2FNType) {
4230 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF6x2Op");
4231 return llvm::Intrinsic::not_intrinsic;
4235#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
4236 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
4237 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
4239#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
4240 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
4241 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4244ConvertF32x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4245 NVVM::SaturationMode sat,
bool hasRelu) {
4246 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4247 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4248 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4251 .Case([&](mlir::Float8E4M3FNType) {
4254 .Case([&](mlir::Float8E5M2Type) {
4257 .Case([&](mlir::Float8E8M0FNUType) {
4258 if (hasRoundingModeRZ)
4260 else if (hasRoundingModeRP)
4263 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
4266 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
4267 return llvm::Intrinsic::not_intrinsic;
4271#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
4272 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
4273 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
4275llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy,
4278 .Case([&](mlir::Float8E4M3FNType) {
4281 .Case([&](mlir::Float8E5M2Type) {
4285 llvm_unreachable(
"Invalid conversion in ConvertF16x2ToF8x2Op");
4286 return llvm::Intrinsic::not_intrinsic;
4290#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
4291 has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
4292 : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
4295ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4296 NVVM::SaturationMode sat) {
4297 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4299 case NVVM::FPRoundingMode::RZ:
4301 case NVVM::FPRoundingMode::RP:
4304 llvm_unreachable(
"Invalid rounding mode for CvtBF16x2ToF8x2Op");
4310 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4312 bool hasRelu = curOp.getRelu();
4314 llvm::Intrinsic::ID intId =
4316 .Case([&](Float8E4M3FNType type) {
4317 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4318 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4320 .Case([&](Float8E5M2Type type) {
4321 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4322 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4325 llvm_unreachable(
"Invalid type for ConvertF8x2ToF16x2Op");
4326 return llvm::Intrinsic::not_intrinsic;
4329 llvm::Value *packedI16 =
4330 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4331 llvm::Type::getInt16Ty(builder.getContext()));
4333 return {intId, {packedI16}};
4338 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4340 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4341 llvm::Value *packedI16 =
4342 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4343 llvm::Type::getInt16Ty(builder.getContext()));
4345 return {intId, {packedI16}};
4350 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4352 bool hasRelu = curOp.getRelu();
4354 llvm::Intrinsic::ID intId =
4356 .Case([&](Float6E2M3FNType type) {
4357 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4358 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4360 .Case([&](Float6E3M2FNType type) {
4361 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4362 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4365 llvm_unreachable(
"Invalid type for ConvertF6x2ToF16x2Op");
4366 return llvm::Intrinsic::not_intrinsic;
4369 llvm::Value *packedI16 =
4370 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4371 llvm::Type::getInt16Ty(builder.getContext()));
4373 return {intId, {packedI16}};
4378 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4380 bool hasRelu = curOp.getRelu();
4382 llvm::Intrinsic::ID intId =
4384 .Case([&](Float4E2M1FNType type) {
4385 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4386 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4389 llvm_unreachable(
"Invalid type for ConvertF4x2ToF16x2Op");
4390 return llvm::Intrinsic::not_intrinsic;
4393 llvm::Value *extendedI16 =
4394 builder.CreateZExt(mt.
lookupValue(curOp.getSrc()),
4395 llvm::Type::getInt16Ty(builder.getContext()));
4397 return {intId, {extendedI16}};
4401Tcgen05AllocOp::getIntrinsicIDAndArgs(
Operation &op,
4404 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
4405 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4407 bool isShared = as == NVVMMemorySpace::Shared;
4408 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4410 llvm::Intrinsic::ID id;
4412 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
4413 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
4415 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
4416 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
4426llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
4429 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
4430 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
4431 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
4432 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
4441#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
4442 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
4443 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
4445#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
4446 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
4447 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
4450Tcgen05CommitOp::getIntrinsicIDAndArgs(
Operation &op,
4453 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
4454 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4456 bool isShared = as == NVVMMemorySpace::Shared;
4457 bool hasMulticast =
static_cast<bool>(curOp.getMulticastMask());
4458 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4460 llvm::Intrinsic::ID
id =
4467 args.push_back(mt.
lookupValue(curOp.getMulticastMask()));
4472#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
4473 llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
4475#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
4476 is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
4477 : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
4479#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
4481 if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
4482 return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
4483 if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
4484 return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
4485 return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
4489ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
4491 llvm::IRBuilderBase &builder) {
4492 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4493 llvm::Intrinsic::nvvm_ff2f16x2_rn,
4494 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
4495 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
4496 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
4498 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4499 llvm::Intrinsic::nvvm_ff2f16x2_rz,
4500 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
4501 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
4502 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
4504 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4505 llvm::Intrinsic::nvvm_ff2f16x2_rs,
4506 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
4507 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
4508 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
4511 unsigned hasRelu = op.getRelu() ? 1 : 0;
4512 unsigned hasSatFinite =
4513 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4516 unsigned idx = (hasSatFinite << 1) | hasRelu;
4521 if (op.getRandomBits())
4522 args.push_back(mt.
lookupValue(op.getRandomBits()));
4524 switch (op.getRnd()) {
4525 case FPRoundingMode::RN:
4526 return {rndRNIds[idx], std::move(args)};
4527 case FPRoundingMode::RZ:
4528 return {rndRZIds[idx], std::move(args)};
4529 case FPRoundingMode::RS:
4530 return {rndRSIds[idx], std::move(args)};
4532 llvm_unreachable(
"Invalid rounding mode for ConvertF32x2ToF16x2Op");
4537ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
4539 llvm::IRBuilderBase &builder) {
4540 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4541 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
4542 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
4543 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
4544 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
4546 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4547 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
4548 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
4549 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
4550 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
4552 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4553 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
4554 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
4555 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
4556 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
4559 unsigned hasRelu = op.getRelu() ? 1 : 0;
4560 unsigned hasSatFinite =
4561 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4564 unsigned idx = (hasSatFinite << 1) | hasRelu;
4569 if (op.getRandomBits())
4570 args.push_back(mt.
lookupValue(op.getRandomBits()));
4572 switch (op.getRnd()) {
4573 case FPRoundingMode::RN:
4574 return {rndRNIds[idx], std::move(args)};
4575 case FPRoundingMode::RZ:
4576 return {rndRZIds[idx], std::move(args)};
4577 case FPRoundingMode::RS:
4578 return {rndRSIds[idx], std::move(args)};
4580 llvm_unreachable(
"Invalid rounding mode for ConvertF32x2ToBF16x2Op");
4584llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
4586 bool hasRelu = getRelu();
4589 .Case([&](mlir::Float8E4M3FNType) {
4590 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
4591 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
4593 .Case([&](mlir::Float8E5M2Type) {
4594 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
4595 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
4598 llvm_unreachable(
"Invalid F8 type in ConvertF32x4ToF8x4Op");
4599 return llvm::Intrinsic::not_intrinsic;
4603llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
4605 bool hasRelu = getRelu();
4608 .Case([&](mlir::Float6E2M3FNType) {
4609 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
4610 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
4612 .Case([&](mlir::Float6E3M2FNType) {
4613 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
4614 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
4617 llvm_unreachable(
"Invalid F6 type in ConvertF32x4ToF6x4Op");
4618 return llvm::Intrinsic::not_intrinsic;
4622llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
4624 bool hasRelu = getRelu();
4627 .Case([&](mlir::Float4E2M1FNType) {
4628 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
4629 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
4632 llvm_unreachable(
"Invalid F4 type in ConvertF32x4ToF4x4Op");
4633 return llvm::Intrinsic::not_intrinsic;
4637llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(
Operation &op) {
4638 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
4639 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
4640 auto srcFmt = curOp.getSrcFormat();
4641 auto mc = curOp.getMulticast();
4643 switch (curOp.getShape()) {
4644 case Tcgen05CpShape::SHAPE_128x256b:
4646 case Tcgen05CpShape::SHAPE_128x128b:
4648 case Tcgen05CpShape::SHAPE_4x256b:
4650 case Tcgen05CpShape::SHAPE_32x128b:
4652 case Tcgen05CpShape::SHAPE_64x128b:
4653 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
4657 llvm_unreachable(
"Invalid shape in tcgen05 cp Op");
4664 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
4666 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
4671LogicalResult Tcgen05LdOp::verify() {
4673 if (
getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4676 if (
getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
4677 result =
emitError(
"offset argument is only supported for shape 16x32bx2");
4679 auto resTy = getRes().getType();
4680 unsigned resLen = isa<VectorType>(resTy)
4681 ? llvm::cast<VectorType>(resTy).getNumElements()
4684 result =
emitError(llvm::formatv(
"invalid result type length {0} for shape "
4685 "{1} in tcgen05.ld Op",
4686 resLen, stringifyEnum(
getShape())));
4691LogicalResult Tcgen05StOp::verify() {
4693 if (
getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4696 auto valTy = getVal().getType();
4697 unsigned valLen = isa<VectorType>(valTy)
4698 ? llvm::cast<VectorType>(valTy).getNumElements()
4701 result =
emitError(llvm::formatv(
"invalid input length {0} for shape "
4702 "{1} in tcgen05.st Op",
4703 valLen, stringifyEnum(
getShape())));
4713 if (
auto rangeAttr = op->
getAttrOfType<LLVM::ConstantRangeAttr>(
"range")) {
4714 setResultRanges(
result, {rangeAttr.getLower(), rangeAttr.getUpper(),
4715 rangeAttr.getLower(), rangeAttr.getUpper()});
4725 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
4729 const llvm::APInt &lower = rangeAttr->getLower();
4730 const llvm::APInt &upper = rangeAttr->getUpper();
4733 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
4734 unsigned bitWidth = lower.getBitWidth();
4735 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
4736 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
4738 "invalid range attribute: Lower == Upper, but they aren't min (")
4739 << llvm::toString(minVal, 10,
false) <<
") or max ("
4740 << llvm::toString(maxVal, 10,
false)
4741 <<
") value! This is an invalid constant range.";
4748 llvm::IRBuilderBase &builder) {
4749 return builder.CreateBitCast(arg,
4750 llvm::Type::getInt32Ty(builder.getContext()));
4755 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
4762 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4763 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4764 unsigned type = (isASigned << 1) | isBSigned;
4765 const llvm::Intrinsic::ID ids[] = {
4766 llvm::Intrinsic::nvvm_idp4a_u_u,
4767 llvm::Intrinsic::nvvm_idp4a_u_s,
4768 llvm::Intrinsic::nvvm_idp4a_s_u,
4769 llvm::Intrinsic::nvvm_idp4a_s_s,
4771 return {ids[type], args};
4776 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
4781 args.push_back(builder.getInt1(curOp.getBHi()));
4784 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4785 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4786 unsigned type = (isASigned << 1) | isBSigned;
4787 const llvm::Intrinsic::ID ids[] = {
4788 llvm::Intrinsic::nvvm_idp2a_u_u,
4789 llvm::Intrinsic::nvvm_idp2a_u_s,
4790 llvm::Intrinsic::nvvm_idp2a_s_u,
4791 llvm::Intrinsic::nvvm_idp2a_s_s,
4793 return {ids[type], args};
4797 llvm::IRBuilderBase &builder) {
4798 return builder.CreateAddrSpaceCast(
4800 llvm::PointerType::get(builder.getContext(),
4801 llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
4805PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
4807 llvm::IRBuilderBase &builder) {
4808 using MemSpace = NVVM::NVVMMemorySpace;
4809 using CacheLevel = NVVM::PrefetchCacheLevel;
4811 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
4812 std::optional<NVVM::CacheEvictionPriority> evictPriority =
4813 op.getEvictPriority();
4814 unsigned addressSpace =
4815 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
4823 if (op.getTensormap())
4824 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
4826 assert(cacheLevel &&
"expected cache level for non-tensormap prefetch");
4828 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
4829 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
4831 if (evictPriority && *cacheLevel == CacheLevel::L2) {
4832 switch (*evictPriority) {
4833 case NVVM::CacheEvictionPriority::EvictLast:
4834 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
4835 case NVVM::CacheEvictionPriority::EvictNormal:
4836 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
4838 llvm_unreachable(
"Invalid cache eviction priority");
4842 switch (
static_cast<MemSpace
>(addressSpace)) {
4843 case MemSpace::Generic:
4844 return *cacheLevel == CacheLevel::L1
4846 :
NVVM::
IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
4847 case MemSpace::Global:
4848 return *cacheLevel == CacheLevel::L1
4850 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
4852 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
4853 case MemSpace::Local:
4854 return *cacheLevel == CacheLevel::L1
4856 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
4858 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
4860 llvm_unreachable(
"Invalid pointer address space");
4864bool NVVM::InlinePtxOp::getAsmValues(
4868 for (
auto arg : getReadWriteArgs())
4870 for (
auto arg : getResults())
4872 for (
auto arg : getReadOnlyArgs())
4879NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
4881 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
4883 args.push_back(mt.
lookupValue(curOp.getSmemAddress()));
4884 args.push_back(mt.
lookupValue(curOp.getMbarrier()));
4886 llvm::Intrinsic::ID intrinsicID =
4887 curOp.getMulticast()
4889 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
4890 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
4892 return {intrinsicID, args};
4895NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
4897 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
4899 args.push_back(mt.
lookupValue(curOp.getTryCancelResponse()));
4901 llvm::Intrinsic::ID intrinsicID;
4903 switch (curOp.getQueryType()) {
4904 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
4906 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
4908 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
4909 intrinsicID = llvm::Intrinsic::
4910 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
4912 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
4913 intrinsicID = llvm::Intrinsic::
4914 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
4916 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
4917 intrinsicID = llvm::Intrinsic::
4918 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
4921 return {intrinsicID, args};
4926 llvm::IRBuilderBase &builder) {
4927 auto thisOp = cast<NVVM::PermuteOp>(op);
4928 NVVM::PermuteMode mode = thisOp.getMode();
4930 static constexpr llvm::Intrinsic::ID IDs[] = {
4931 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
4932 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
4933 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
4934 llvm::Intrinsic::nvvm_prmt_rc16};
4936 unsigned modeIndex =
static_cast<unsigned>(mode);
4944 args.push_back(mt.
lookupValue(thisOp.getSelector()));
4946 return {IDs[modeIndex], args};
4951 auto thisOp = cast<NVVM::TensormapReplaceOp>(op);
4955 if (thisOp.getOrd())
4956 args.push_back(builder.getInt32(thisOp.getOrd().value()));
4957 if (thisOp.getNewValue())
4958 args.push_back(mt.
lookupValue(thisOp.getNewValue()));
4959 if (
auto attr = thisOp.getNewValueAttr()) {
4962 .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
4963 TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
4964 TensormapFillModeAttr>([](
auto attr) {
4965 return static_cast<unsigned>(attr.getValue());
4967 .Default([](
auto attr) {
4968 llvm_unreachable(
"Invalid attribute type");
4971 args.push_back(builder.getInt32(val));
4974 static constexpr llvm::Intrinsic::ID IDs[] = {
4975 llvm::Intrinsic::nvvm_tensormap_replace_global_address,
4976 llvm::Intrinsic::nvvm_tensormap_replace_rank,
4977 llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
4978 llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
4979 llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
4980 llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
4981 llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
4982 llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
4983 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
4984 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
4985 llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
4988 unsigned fieldIndex =
static_cast<unsigned>(thisOp.getField());
4990 return {IDs[fieldIndex], args};
4999 llvm::IRBuilderBase &builder) {
5001 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
5004 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5007 const bool isATensor = isa<llvm::PointerType>(
A->getType());
5010 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5011 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5012 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5014 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5015 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5016 using IsATensorArray = std::array<CtaGroupArray, 2>;
5017 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5018 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5021 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
5027 {llvm::Intrinsic::nvvm_tcgen05_mma_shared,
notIntrinsic},
5029 {llvm::Intrinsic::nvvm_tcgen05_mma_shared,
notIntrinsic}}},
5033 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5034 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5038 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5039 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5045 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d,
notIntrinsic},
5047 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d,
notIntrinsic}}},
5051 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5052 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5056 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5057 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5063 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
5066 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
5071 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
5073 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
5078 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
5080 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
5086 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
5090 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
5095 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
5097 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
5101 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
5103 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
5106 llvm::Value *ScaleInputD = mt.
lookupValue(thisOp.getScaleInputD());
5107 bool hasScaleInputD = ScaleInputD !=
nullptr;
5109 llvm::Value *DisableOutputLane =
5111 bool hasDisableOutputLane = DisableOutputLane !=
nullptr;
5113 const unsigned ctaGroup =
5116 llvm::Intrinsic::ID ID =
5117 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5118 [ctaGroup - 1][thisOp.getAShift()];
5120 assert(ID !=
notIntrinsic &&
"Invalid intrinsic for Tcgen05MMAOp.");
5123 args.push_back(ScaleInputD);
5125 if (hasDisableOutputLane)
5126 args.push_back(DisableOutputLane);
5128 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5130 if (!hasDisableOutputLane)
5131 args.push_back(builder.getInt32(ctaGroup));
5134 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5141 NVVM::CTAGroupKind ctaGroup,
bool hasAShift,
5142 NVVM::Tcgen05MMACollectorOp collectorOp,
Location loc) {
5144 if (disableOutputLane) {
5145 mlir::VectorType disableOutputLaneType =
5146 cast<mlir::VectorType>(disableOutputLane.
getType());
5147 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
5148 disableOutputLaneType.getNumElements() != 4) ||
5149 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
5150 disableOutputLaneType.getNumElements() != 8))
5151 return emitError(loc) <<
"Disable Output Lane of length "
5152 << disableOutputLaneType.getNumElements()
5153 <<
" is incompatible with CtaGroupAttr";
5156 if (hasAShift && !isATensor)
5158 loc,
"A-shift can be applied only when matrix A is in tensor memory");
5160 if (hasAShift ==
true && (collectorOp == Tcgen05MMACollectorOp::FILL ||
5161 collectorOp == Tcgen05MMACollectorOp::USE))
5163 loc,
"Cannot use collector buffer operation fill or use with ashift");
5168LogicalResult Tcgen05MMAOp::verify() {
5170 getDisableOutputLane(), getCtaGroup(), getAShift(),
5171 getCollectorOp(), getLoc());
5181 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
5184 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5187 bool isATensor = isa<llvm::PointerType>(
A->getType());
5190 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5191 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5192 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5193 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
5195 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5196 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5197 using IsATensorArray = std::array<CtaGroupArray, 2>;
5198 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5199 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5202 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
5208 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared,
notIntrinsic},
5210 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared,
notIntrinsic}}},
5214 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5215 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5219 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5220 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5226 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5229 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5234 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5235 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5239 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5240 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5247 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5251 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5256 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5258 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5263 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5265 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5271 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5275 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5280 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5282 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5286 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5288 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5291 llvm::Value *ScaleInputD = mt.
lookupValue(thisOp.getScaleInputD());
5292 bool hasScaleInputD = ScaleInputD !=
nullptr;
5294 llvm::Value *DisableOutputLane =
5296 bool hasDisableOutputLane = DisableOutputLane !=
nullptr;
5301 llvm::Intrinsic::ID ID =
5302 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5303 [ctaGroup - 1][thisOp.getAShift()];
5305 assert(ID !=
notIntrinsic &&
"Invalid intrinsic for Tcgen05MMASparseOp.");
5308 args.push_back(ScaleInputD);
5310 if (hasDisableOutputLane)
5311 args.push_back(DisableOutputLane);
5313 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5315 if (!hasDisableOutputLane)
5316 args.push_back(builder.getInt32(ctaGroup));
5319 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5324LogicalResult Tcgen05MMASparseOp::verify() {
5326 getDisableOutputLane(), getCtaGroup(), getAShift(),
5327 getCollectorOp(), getLoc());
5337 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5340 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5343 bool isATensor = isa<llvm::PointerType>(
A->getType());
5346 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5347 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5348 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5349 args.push_back(mt.
lookupValue(thisOp.getScaleA()));
5350 args.push_back(mt.
lookupValue(thisOp.getScaleB()));
5351 args.push_back(builder.getInt32(
5354 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5356 auto kind = thisOp.getKind();
5357 auto blockScale = thisOp.getBlockScale();
5358 llvm::Intrinsic::ID ID = [&]() {
5359 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5360 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5361 return isATensor ? llvm::Intrinsic::
5362 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
5364 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
5365 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5368 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
5370 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
5372 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5373 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5375 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
5376 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
5377 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5378 return isATensor ? llvm::Intrinsic::
5379 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
5381 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
5383 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5384 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5387 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
5389 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
5391 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5394 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
5396 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
5399 llvm_unreachable(
"Invalid tcgen05.mma.block_scale attributes");
5406 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind,
5407 NVVM::Tcgen05MMABlockScale blockScale,
Location loc) {
5408 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
5409 kind == NVVM::Tcgen05MMAKind::MXF4NVF4)
5410 return emitError(loc,
"mxf4nvf4 requires block scale attribute");
5412 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
5413 kind != NVVM::Tcgen05MMAKind::MXF4NVF4)
5415 llvm::formatv(
"{} kind does not support block16 attribute",
5416 stringifyEnum(kind)));
5421LogicalResult Tcgen05MMABlockScaleOp::verify() {
5423 getBlockScale(), getLoc());
5433 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
5436 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5439 bool isATensor = isa<llvm::PointerType>(
A->getType());
5442 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5443 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5444 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5445 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
5446 args.push_back(mt.
lookupValue(thisOp.getScaleA()));
5447 args.push_back(mt.
lookupValue(thisOp.getScaleB()));
5448 args.push_back(builder.getInt32(
5451 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5453 auto kind = thisOp.getKind();
5454 auto blockScale = thisOp.getBlockScale();
5455 llvm::Intrinsic::ID ID = [&]() {
5456 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5457 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5458 return isATensor ? llvm::Intrinsic::
5459 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
5461 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
5462 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5465 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
5467 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
5469 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5470 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5471 return isATensor ? llvm::Intrinsic::
5472 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
5474 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
5475 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5478 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
5480 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
5482 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5483 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5486 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
5488 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
5490 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5493 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
5495 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
5498 llvm_unreachable(
"Invalid tcgen05.mma.sp.block_scale attributes");
5504LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
5506 getBlockScale(), getLoc());
5516 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
5519 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5522 bool isATensor = isa<llvm::PointerType>(
A->getType());
5525 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5526 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5527 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5529 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5533 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
5534 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
5536 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
5537 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
5539 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5541 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5543 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5555 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
5558 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5561 bool isATensor = isa<llvm::PointerType>(
A->getType());
5564 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5565 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5566 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5567 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
5569 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5574 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
5575 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
5577 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
5578 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
5580 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5582 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5584 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5593#define TCGEN05LDRED(SHAPE, NUM, TYPE) \
5594 llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE
5598 auto thisOp = cast<NVVM::Tcgen05LdRedOp>(op);
5601 mlir::VectorType VecResTy =
5602 cast<mlir::VectorType>(thisOp.getData().getType());
5603 unsigned Num = VecResTy.getNumElements();
5604 bool IsFloat = thisOp.getRedVal().getType().isF32();
5606 llvm::Intrinsic::ID Shape32x32b[][2] = {
5617 llvm::Intrinsic::ID Shape16x32bx2[][2] = {
5628 NVVM::Tcgen05LdStShape
shape = thisOp.getShape();
5629 unsigned ID = [&]() {
5632 unsigned idx = std::log2(Num);
5634 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
5635 return Shape32x32b[idx][IsFloat];
5636 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
5637 return Shape16x32bx2[idx][IsFloat];
5639 llvm_unreachable(
"unhandled tcgen05.ld lowering");
5645 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2)
5646 args.push_back(mt.
lookupValue(thisOp.getOffset()));
5649 builder.getInt32(thisOp.getOp() == NVVM::ReductionKind::MIN ? 0 : 1));
5652 args.push_back(builder.getInt1(
static_cast<unsigned>(thisOp.getAbs())));
5653 args.push_back(builder.getInt1(
static_cast<unsigned>(thisOp.getNan())));
5658LogicalResult Tcgen05LdRedOp::verify() {
5659 VectorType data = cast<VectorType>(getData().
getType());
5660 Type redVal = getRedVal().getType();
5662 if (data.getElementType() != redVal)
5664 "type of reduction value and element type of vector data should match");
5666 if (getOp() != NVVM::ReductionKind::MIN &&
5667 getOp() != NVVM::ReductionKind::MAX)
5668 return emitError(
"only min and max reduction kinds are supported");
5670 if (redVal.
isInteger() && (getAbs() || getNan())) {
5671 return emitError(
"abs or nan is only applicable for f32 type");
5681void NVVMDialect::initialize() {
5684#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5687#define GET_ATTRDEF_LIST
5688#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
5693 allowUnknownOperations();
5694 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
5695 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
5698LogicalResult NVVMDialect::verifyOperationAttribute(
Operation *op,
5700 StringAttr attrName = attr.
getName();
5702 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
5703 if (!isa<LLVM::LLVMFuncOp>(op)) {
5704 return op->
emitError() <<
"'" << NVVMDialect::getKernelFuncAttrName()
5705 <<
"' attribute attached to unexpected op";
5710 if (attrName == NVVMDialect::getMaxntidAttrName() ||
5711 attrName == NVVMDialect::getReqntidAttrName() ||
5712 attrName == NVVMDialect::getClusterDimAttrName()) {
5713 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
5714 if (!values || values.empty() || values.size() > 3) {
5717 <<
"' attribute must be integer array with maximum 3 index";
5722 if (attrName == NVVMDialect::getMinctasmAttrName() ||
5723 attrName == NVVMDialect::getMaxnregAttrName() ||
5724 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
5725 if (!llvm::dyn_cast<IntegerAttr>(attr.
getValue())) {
5727 <<
"'" << attrName <<
"' attribute must be integer constant";
5731 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
5732 if (!op->
hasAttr(NVVMDialect::getReqntidAttrName()) ||
5733 !op->
hasAttr(NVVMDialect::getClusterDimAttrName())) {
5735 <<
"'" << attrName <<
"' attribute must be used along with "
5736 <<
"'" << NVVMDialect::getReqntidAttrName() <<
"' and "
5737 <<
"'" << NVVMDialect::getClusterDimAttrName() <<
"'";
5744LogicalResult NVVMDialect::verifyRegionArgAttribute(
Operation *op,
5745 unsigned regionIndex,
5748 auto funcOp = dyn_cast<FunctionOpInterface>(op);
5752 bool isKernel = op->
hasAttr(NVVMDialect::getKernelFuncAttrName());
5753 StringAttr attrName = argAttr.
getName();
5754 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
5758 <<
"' attribute must be present only on kernel arguments";
5760 if (!isa<UnitAttr>(argAttr.
getValue()))
5761 return op->
emitError() <<
"'" << attrName <<
"' must be a unit attribute";
5762 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
5765 <<
"' attribute requires the argument to also have attribute '"
5766 << LLVM::LLVMDialect::getByValAttrName() <<
"'";
5777unsigned NVVMMemorySpaceAttr::getAddressSpace()
const {
5778 return static_cast<unsigned>(getValue());
5781bool NVVMMemorySpaceAttr::isValidLoad(
5782 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5783 const ::mlir::DataLayout *dataLayout,
5789bool NVVMMemorySpaceAttr::isValidStore(
5790 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5791 const ::mlir::DataLayout *dataLayout,
5797bool NVVMMemorySpaceAttr::isValidAtomicOp(
5798 ptr::AtomicBinOp op,
Type type, ptr::AtomicOrdering ordering,
5799 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
5802 assert(
false &&
"unimplemented, see TODO in the source.");
5806bool NVVMMemorySpaceAttr::isValidAtomicXchg(
5807 Type type, ptr::AtomicOrdering successOrdering,
5808 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
5809 const ::mlir::DataLayout *dataLayout,
5812 assert(
false &&
"unimplemented, see TODO in the source.");
5816bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
5820 assert(
false &&
"unimplemented, see TODO in the source.");
5824bool NVVMMemorySpaceAttr::isValidPtrIntCast(
5829 assert(
false &&
"unimplemented, see TODO in the source.");
5838 int optLevel, StringRef triple, StringRef chip,
5839 StringRef features, DictionaryAttr flags,
5841 if (optLevel < 0 || optLevel > 3) {
5842 emitError() <<
"The optimization level must be a number between 0 and 3.";
5845 if (triple.empty()) {
5846 emitError() <<
"The target triple cannot be empty.";
5850 emitError() <<
"The target chip cannot be empty.";
5853 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
5854 return mlir::isa_and_nonnull<StringAttr>(attr);
5856 emitError() <<
"All the elements in the `link` array must be strings.";
5862LogicalResult NVVMTargetAttr::verifyTarget(
Operation *gpuModule) {
5863 if (!getVerifyTarget())
5866 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
5869 "NVVM target attribute must be attached to a GPU module");
5872 const NVVMCheckSMVersion targetSMVersion =
5876 "Minimum NVVM target SM version is sm_20");
5880 ->
walk([&](Operation *op) {
5881 if (
auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
5882 const NVVMCheckSMVersion requirement =
5883 reqOp.getRequiredMinSMVersion();
5885 op->
emitOpError() <<
"is not supported on " << getChip();
5897#define GET_OP_CLASSES
5898#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5900#define GET_ATTRDEF_CLASSES
5901#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
static void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyAddSubFOp(OpType op)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
static void printOperandList(OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static llvm::Value * castPtrToAddrSpace(llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)
static void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
static LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > ®s)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
#define TCGEN05LDRED(SHAPE, NUM, TYPE)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static MMATypes inferPtxTypeFromResult(OpTy op)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static LogicalResult parseMmaTypeSignature(OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
static bool isPtrInSharedClusterSpace(mlir::Value ptr)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static bool isPtrInGenericSpace(mlir::Value ptr)
static void processOperandFragments(Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > ®Types, SmallVectorImpl< StringRef > &ignoreAttrNames)
static constexpr unsigned notIntrinsic
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class represents a diagnostic that is inflight and set to be reported.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given chracteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
LogicalResult matchAndRewrite(SubFOp op, PatternRewriter &rewriter) const override
bool isMinimumSMVersion() const
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const
static const NVVMCheckSMVersion getTargetSMVersionFromStr(StringRef smVersionString)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.