#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
  auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
  return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);

                                          NVVMMemorySpace targetAS) {
  unsigned AS = static_cast<unsigned>(targetAS);
  return builder.CreateAddrSpaceCast(
      ptr, llvm::PointerType::get(builder.getContext(), AS));

static llvm::nvvm::CTAGroupKind
  case NVVM::CTAGroupKind::CTA_1:
    return llvm::nvvm::CTAGroupKind::CG_1;
  case NVVM::CTAGroupKind::CTA_2:
    return llvm::nvvm::CTAGroupKind::CG_2;
  llvm_unreachable("unsupported cta_group value");

                                         size_t numIm2ColOffsets,
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");
        "to use im2col mode, the tensor has to be at least 3-dimensional");
  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
    return emitError(
        loc, "im2col offsets must be 2 less than number of coordinates");

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  TMAStoreMode mode = getMode();
  if (getPredicate()) {
    if (mode != TMAStoreMode::TILE)
      return emitError("Inline-ptx lowering supported only for Tile mode.");
    if (getL2CacheHint())
      return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter4 mode expects 5 coordinates");
LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only support for 16 bytes copy.");
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedIm2colOff) -> LogicalResult {
    if (isIm2col && (tensorDims < 3))
             << " mode, the tensor has to be at least 3-dimensional";

    if (numIm2colOff != expectedIm2colOff)
      return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
                            << " (provided " << numIm2colOff << ")";

  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    return (tensorDims == 5)
               ? checkTMALoadParams(mode, false, 0)
               : emitError(loc, "Gather4 mode expects 5 coordinates");
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
                              getMode(), getLoc());

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  TMALoadMode mode = getMode();
  bool isCTAOnly = getIsCTAOnly();
  if (getPredicate()) {
      return emitError("Predicate is supported only for shared::cluster mode.");
    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
          "Predicate is supported only for Tile and Im2col modes.");

  NVVMMemorySpace expectedAS =
      isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
  if (AS != expectedAS)
                   ? "Shared::cta destination requires address-space 3."
                   : "Shared::cluster destination requires address-space 7.");

  if (getMulticastMask())
    return emitError("Multicast is not supported with shared::cta mode.");
    return emitError("CTAGroup is not supported with shared::cta mode.");

                              getMode(), getLoc());

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  TMAStoreMode mode = getMode();
  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");

LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
  if (isSharedCTA && getMulticastMask())
    return emitError("Multicast is not supported with shared::cta mode.");
                                       NVVM::MemScopeKind scope,
                                       Value retVal = nullptr) {
  if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
    return op->emitError("mbarrier scope must be either CTA or Cluster");

  bool hasRetValue = static_cast<bool>(retVal);
  if (isSharedCluster && hasRetValue)
        "mbarrier in shared_cluster space cannot return any value");
LogicalResult MBarrierArriveOp::verify() {

LogicalResult MBarrierArriveDropOp::verify() {

LogicalResult MBarrierArriveExpectTxOp::verify() {
  if (getPredicate()) {
    if (getScope() != NVVM::MemScopeKind::CTA)
      return emitError("mbarrier scope must be CTA when using predicate");
      return emitError("mbarrier in shared_cluster space is not supported when "
      return emitError("return-value is not supported when using predicate");
    if (getRelaxed() == true)
      return emitError("mbarrier with relaxed semantics is not supported when "

LogicalResult MBarrierArriveDropExpectTxOp::verify() {

LogicalResult MBarrierExpectTxOp::verify() {

LogicalResult MBarrierCompleteTxOp::verify() {

LogicalResult MBarrierTestWaitOp::verify() {

LogicalResult MBarrierTryWaitOp::verify() {

LogicalResult ConvertFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
    return emitError("Relu not supported with rna rounding mode.");
        "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");

LogicalResult ConvertF32x2ToF6x2Op::verify() {
  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f32x2 to f6x2.";

LogicalResult ConvertF32x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;
  bool hasRelu = getRelu();

      .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
            if (!isRoundingModeRN) {
              return emitOpError("Only RN rounding mode is supported for "
                                 "conversions from f32x2 to ")
                     << mlir::Float8E4M3FNType::get(ctx) << " and "
                     << mlir::Float8E5M2Type::get(ctx) << " types";
              return emitOpError("Only SATFINITE saturation mode is supported "
                     << mlir::Float8E4M3FNType::get(ctx) << " and "
                     << mlir::Float8E5M2Type::get(ctx) << " types";
      .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
        if (!(isRoundingModeRZ || isRoundingModeRP)) {
          return emitOpError("Only RZ and RP rounding modes are supported for "
                             "conversions from f32x2 to ")
                 << mlir::Float8E8M0FNUType::get(ctx) << " type";
          return emitOpError("relu not supported for conversions to ")
                 << mlir::Float8E8M0FNUType::get(ctx) << " type";
               << mlir::Float8E4M3FNType::get(ctx) << ", "
               << mlir::Float8E5M2Type::get(ctx) << ", and "
               << mlir::Float8E8M0FNUType::get(ctx)
                  "supported for conversions from f32x2 to f8x2";
LogicalResult ConvertF16x2ToF8x2Op::verify() {
  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f16x2 to f8x2.";

LogicalResult ConvertBF16x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;

  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
           << " type is supported for conversions from "

  if (rnd != RndMode::RZ && rnd != RndMode::RP)
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");

LogicalResult ConvertF32x2ToF4x2Op::verify() {
  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
           << mlir::Float4E2M1FNType::get(ctx)
           << " type is supported for conversions from f32x2 to f4x2.";

LogicalResult ConvertF8x2ToF16x2Op::verify() {
  if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f8x2 to f16x2.";

LogicalResult ConvertF8x2ToBF16x2Op::verify() {
  if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
           << mlir::Float8E8M0FNUType::get(ctx)
           << " type is supported for conversions from f8x2 to bf16x2.";

LogicalResult ConvertF6x2ToF16x2Op::verify() {
  if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f6x2 to f16x2.";

LogicalResult ConvertF4x2ToF16x2Op::verify() {
  if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
           << mlir::Float4E2M1FNType::get(ctx)
           << " type is supported for conversions from f4x2 to f16x2.";

LogicalResult PermuteOp::verify() {
  using Mode = NVVM::PermuteMode;
  bool hasHi = static_cast<bool>(getHi());
      return emitError("mode '") << getMode() << "' requires 'hi' operand.";
        << getMode() << "' does not accept 'hi' operand.";
  static constexpr FPRoundingMode validRndModes[] = {
      FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};

  if (!llvm::is_contained(validRndModes, rnd)) {
        "Only RN, RZ, and RS rounding modes are supported for "
        "conversions from f32x2 to ")

  if (rnd == FPRoundingMode::RS) {
    if (!hasRandomBits) {
      return op->emitOpError("random_bits is required for RS rounding mode.");
        "random_bits not supported for RN and RZ rounding modes.");
LogicalResult ConvertF32x2ToF16x2Op::verify() {
                                getRandomBits() ? true : false, *this);

LogicalResult ConvertF32x2ToBF16x2Op::verify() {
                                getRandomBits() ? true : false, *this);

LogicalResult ConvertF32x4ToF8x4Op::verify() {
  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f32x4 to f8x4.";

LogicalResult ConvertF32x4ToF6x4Op::verify() {
  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f32x4 to f6x4.";

LogicalResult ConvertF32x4ToF4x4Op::verify() {
  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
                                << " type is supported for conversions from "

LogicalResult BulkStoreOp::verify() {
  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ") << getInitVal();

LogicalResult PMEventOp::verify() {
  auto eventId = getEventId();
  auto maskedEventId = getMaskedEventId();
  if (!maskedEventId && !eventId) {
    return emitOpError() << "either `id` or `mask` must be set";

  if (maskedEventId && eventId) {
    return emitOpError() << "`id` and `mask` cannot be set at the same time";

  if (eventId < 0 || eventId > 15) {
    return emitOpError() << "`id` must be between 0 and 15";

  return llvm::success();
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
      return NVVM::MMATypes::s32;

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
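
// Inference rules implemented above: f64 -> f64, f16 or vector<2xf16> -> f16,
// f32 -> f32 when used as an accumulator and tf32 otherwise, integer types ->
// s32, and an LLVM struct defers to the type of its first element.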
  return (type == MMATypes::u4 || type == MMATypes::s4);

  return (type == MMATypes::u8 || type == MMATypes::s8);

         type == MMATypes::s32;

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), true);
  assert(val.has_value() && "result PTX type should always be inferrable");

  struct MMAOperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}

  std::array<MMAOperandFragment, 3> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
    std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
        regTypes.back(), fragIdx >= 2);
      ignoreAttrNames.push_back(frag.ptxTypeAttr);

  auto printMmaOperand = [&](const MMAOperandFragment &frag) -> void {
    p << " " << frag.operandName;

  for (const auto &frag : frags) {
    printMmaOperand(frag);

                       frags[1].regs[0].getType(),
                       frags[2].regs[0].getType()},
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
                      MmaOp::getOperandSegmentSizeAttr(),
                       static_cast<int32_t>(operandB.size()),
                       static_cast<int32_t>(operandC.size())}));

  struct MMAOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;

  std::array<MMAOperandFragment, 4> frags;

                               MMAOperandFragment &frag) -> LogicalResult {

  if (operandTypes.size() != 3)
        "expected one type for each operand segment but got " +
        Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],

  frags[3].elemtype = inferOperandMMAType(resultType, true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
          "attribute " + names[idx] +
          " is not provided explicitly and cannot be inferred");
    if (!attr.has_value())
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                       static_cast<int32_t>(frags[0].regs.size()),
                       static_cast<int32_t>(frags[1].regs.size()),
                       static_cast<int32_t>(frags[2].regs.size()),
LogicalResult MmaOp::verify() {
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});

  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;

  if (mmaShape[0] == 16) {
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      return emitError("invalid shape or multiplicand type: ")
             << getMultiplicandAPtxType().value();

      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});

    if (resultPtxType() != accumPtxType())

  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
      allowedShapes.push_back({8, 8, 4});
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
        allowedShapes.push_back({8, 8, 32});
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);

  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
    auto spec = this->getODSOperandIndexAndLength(iter.index());
                                 operand_type_begin() + spec.first +
    bool match = llvm::is_contained(iter.value(), operandTySeg);

      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);

  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +

    if (!getIntOverflowBehavior())
                         getIntOverflowBehaviorAttrName().strref() +

      (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
       getMultiplicandAPtxType() == MMATypes::f16);

  if (!isM8N8K4_F16) {
    if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
      return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
                         "layoutB = #nvvm.mma_layout<col> for shape <")
             << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
             << "> with element types " << *getMultiplicandAPtxType() << " and "
             << *getMultiplicandBPtxType()
             << ". Only m8n8k4 with f16 supports other layouts.";
MMATypes MmaSpOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
      getODSOperands(2).getTypes().front(), true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");

MMATypes MmaSpOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      MmaOp::inferOperandMMAType(getResult().getType(), true);
  assert(val.has_value() && "result PTX type should always be inferrable");

                                       llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::MmaSpOp>(op);

  auto intId = MmaSpOp::getIntrinsicID(
      thisOp.getShape().getM(), thisOp.getShape().getN(),
      thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
      thisOp.getOrderedMetadata(), thisOp.getKind(),
      *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
      thisOp.accumPtxType(), thisOp.resultPtxType());

  return {intId, args};

  struct MMAOperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}

  std::array<MMAOperandFragment, 5> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", ""), MMAOperandFragment("sparseMetadata", ""),
      MMAOperandFragment("selector", "")};
      mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == varOperandSpec.first) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
    std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
        regTypes.back(), fragIdx >= 2);
      ignoreAttrNames.push_back(frag.ptxTypeAttr);

  frags[3].regs.push_back(getSparseMetadata());
  frags[4].regs.push_back(getSparsitySelector());

  auto printMmaSpOperand = [&](const MMAOperandFragment &frag) -> void {
    p << " " << frag.operandName;

  for (const auto &frag : frags)
    printMmaSpOperand(frag);

  for (int i = 0; i < 3; ++i) {

  p << ") -> " << getResult().getType();
                    std::optional<MMAIntOverflow> intOverflow,
                    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
  result.addOperands(sparseMetadata);
  result.addOperands(sparsitySelector);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
    if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));

  result.addTypes(resultType);
                      MmaSpOp::getOperandSegmentSizeAttr(),
                       static_cast<int32_t>(operandB.size()),
                       static_cast<int32_t>(operandC.size()), 1,

  struct MMAOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;

  std::array<MMAOperandFragment, 6> frags;

  auto parseMmaSpOperand = [&](StringRef operandName,
                               MMAOperandFragment &frag) -> LogicalResult {

  if (parseMmaSpOperand("A", frags[0]).failed())
  if (parseMmaSpOperand("B", frags[1]).failed())
  if (parseMmaSpOperand("C", frags[2]).failed())
  if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
  if (parseMmaSpOperand("selector", frags[4]).failed())

  if (operandTypes.size() != 3)
        "expected one type for each operand segment but got " +
        Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
        MmaOp::inferOperandMMAType(frag.regTypes[0],
        MmaOp::inferOperandMMAType(resultType, true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
          "attribute " + names[idx] +
          " is not provided explicitly and cannot be inferred");
    if (!attr.has_value())
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
                       static_cast<int32_t>(frags[0].regs.size()),
                       static_cast<int32_t>(frags[1].regs.size()),
                       static_cast<int32_t>(frags[2].regs.size()),
LogicalResult MmaSpOp::verify() {
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;

  if (mmaShape[0] == 16) {
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      allowedShapes.push_back({16, 8, 8});
      allowedShapes.push_back({16, 8, 16});
    case MMATypes::bf16:
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      allowedShapes.push_back({16, 8, 16});
      allowedShapes.push_back({16, 8, 32});
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      allowedShapes.push_back({16, 8, 16});
      allowedShapes.push_back({16, 8, 32});
      allowedShapes.push_back({16, 8, 64});
      allowedShapes.push_back({16, 8, 128});
      allowedShapes.push_back({16, 8, 32});
      allowedShapes.push_back({16, 8, 64});
    case MMATypes::e4m3:
    case MMATypes::e5m2:
    case MMATypes::e3m2:
    case MMATypes::e2m3:
    case MMATypes::e2m1:
      multiplicandFragType = i32Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      allowedShapes.push_back({16, 8, 64});
      return emitError("invalid shape or multiplicand type: ")
             << getMultiplicandAPtxType().value();

      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
               *getMultiplicandAPtxType() <= MMATypes::e2m1) {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);

    if (resultPtxType() != accumPtxType())

  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
      allowedShapes.push_back({8, 8, 4});
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
        allowedShapes.push_back({8, 8, 32});
        allowedShapes.push_back({8, 8, 16});

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);

  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
    auto spec = this->getODSOperandIndexAndLength(iter.index());
                                 operand_type_begin() + spec.first +
    bool match = llvm::is_contained(iter.value(), operandTySeg);

      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);

  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();

  if (!getIntOverflowBehavior())
                         getIntOverflowBehaviorAttrName().strref() +

  if (!getSparseMetadata().getType().isInteger(32)) {
    return emitOpError() << "sparse metadata must be i32 type";

  if (!getSparsitySelector().getType().isInteger(32)) {
    return emitOpError() << "sparsity selector must be i32 type";
struct MMAOperandFragment {
  StringRef operandName;
  StringRef ptxTypeAttr;
  SmallVector<Value, 4> regs;
  explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
      : operandName(name), ptxTypeAttr(ptxTypeName) {}

  p << " " << name << "[";

template <typename Op>
  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
      frag.regs.push_back(op.getOperand(operandIdx));
      if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
        regTypes.push_back(op.getOperand(operandIdx).getType());
      regTypes.push_back(frag.regs[0].getType());
    std::optional<MMATypes> inferredType =
        MmaOp::inferOperandMMAType(regTypes.back(),
      ignoreAttrNames.push_back(frag.ptxTypeAttr);

  auto typeParser = [&]() {
    operandTypes.push_back(ty);

  if (operandTypes.size() != 3)
        "expected exactly 3 types");

  if (!attrs.get("multiplicandAPtxType")) {
    if (auto inferredType =
            MmaOp::inferOperandMMAType(operandTypes[0], false)) {
      attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
  if (!attrs.get("multiplicandBPtxType")) {
    if (auto inferredType =
            MmaOp::inferOperandMMAType(operandTypes[1], false)) {
      attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
template <typename OpType>
                                  ScaleVecSize scaleVecSize,
                                  BlockScaleFormat blockScaleFormat,
                                  MMABlockScaleKind kind) {
  auto &properties = result.getOrAddProperties<typename OpType::Properties>();
  properties.setShape(
  properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
  properties.setBlockScaleFormat(
      BlockScaleFormatAttr::get(ctx, blockScaleFormat));
  properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));

    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
    if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));

template <typename OpTy>
  return *MmaOp::inferOperandMMAType(
      cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
  std::array<MMAOperandFragment, 3> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
      mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};

  for (const auto &frag : frags)

               {getScaleAData(), getByteIdA(), getThreadIdA()});
               {getScaleBData(), getByteIdB(), getThreadIdB()});

                       frags[1].regs[0].getType(),
                       frags[2].regs[0].getType()},

ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
  struct LocalOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;

  std::array<LocalOperandFragment, 3> frags;

  for (const auto &[idx, frag] : llvm::enumerate(frags)) {
    frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
          .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
          .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),

  result.addAttributes(namedAttributes);
  result.addTypes(resultTypes);
  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
                       static_cast<int32_t>(frags[0].regs.size()),
                       static_cast<int32_t>(frags[1].regs.size()),
                       static_cast<int32_t>(frags[2].regs.size()),

void MmaBlockScaleOp::build(
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
    MMABlockScaleKind kind) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
                        blockScaleFormat, kind);

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
      {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
                        multiplicandPtxTypes);

  result.addTypes(resultType);
  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
                       static_cast<int32_t>(operandA.size()),
                       static_cast<int32_t>(operandB.size()),
                       static_cast<int32_t>(operandC.size()),

  auto curOp = cast<NVVM::MmaBlockScaleOp>(op);

  for (Value operand : curOp.getOperandA())
  for (Value operand : curOp.getOperandB())
  for (Value operand : curOp.getOperandC())

  args.push_back(mt.lookupValue(curOp.getScaleAData()));
  args.push_back(mt.lookupValue(curOp.getByteIdA()));
  args.push_back(mt.lookupValue(curOp.getThreadIdA()));
  args.push_back(mt.lookupValue(curOp.getScaleBData()));
  args.push_back(mt.lookupValue(curOp.getByteIdB()));
  args.push_back(mt.lookupValue(curOp.getThreadIdB()));

  unsigned intId = MmaBlockScaleOp::getIntrinsicID(
      curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
      *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
      curOp.getBlockScaleFormat(), curOp.getKind());

  return {intId, args};
LogicalResult MmaBlockScaleOp::verify() {
  if (m == 16 && n == 8 && k == 64) {
    if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
        getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
          "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
    if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
      if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
            "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
      if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
            "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
    } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
      if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
            (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
            "attributes for mma.m16n8k64.mxf4nvf4");
  } else if (m == 16 && n == 8 && k == 32) {
    if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
          getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
          getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
      return emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
                         "attributes for mma.m16n8k32");
  std::array<MMAOperandFragment, 3> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
      mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};

  for (const auto &frag : frags)

               {getScaleAData(), getByteIdA(), getThreadIdA()});
               {getScaleBData(), getByteIdB(), getThreadIdB()});

                       frags[1].regs[0].getType(),
                       frags[2].regs[0].getType()},

ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
  struct LocalOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;

  std::array<LocalOperandFragment, 3> frags;

  for (const auto &[idx, frag] : llvm::enumerate(frags)) {
    frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
          .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
          .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
          .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),

  result.addAttributes(namedAttributes);

  if (!result.attributes.get("orderedMetadata"))

  result.addTypes(resultTypes);
  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
                       static_cast<int32_t>(frags[0].regs.size()),
                       static_cast<int32_t>(frags[1].regs.size()),
                       static_cast<int32_t>(frags[2].regs.size()),

void MmaSpBlockScaleOp::build(
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
    MMABlockScaleKind kind) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
      builder, result, shape, scaleVecSize, blockScaleFormat, kind);

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
  result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
                      threadIdA, scaleBData, byteIdB, threadIdB});
                        multiplicandPtxTypes);

  result.addTypes(resultType);
  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
                       static_cast<int32_t>(operandA.size()),
                       static_cast<int32_t>(operandB.size()),
                       static_cast<int32_t>(operandC.size()),

  auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);

  for (Value operand : curOp.getOperandA())
  for (Value operand : curOp.getOperandB())
  for (Value operand : curOp.getOperandC())

  args.push_back(mt.lookupValue(curOp.getSparseMetadata()));
  args.push_back(mt.lookupValue(curOp.getSparsitySelector()));

  args.push_back(mt.lookupValue(curOp.getScaleAData()));
  args.push_back(mt.lookupValue(curOp.getByteIdA()));
  args.push_back(mt.lookupValue(curOp.getThreadIdA()));
  args.push_back(mt.lookupValue(curOp.getScaleBData()));
  args.push_back(mt.lookupValue(curOp.getByteIdB()));
  args.push_back(mt.lookupValue(curOp.getThreadIdB()));

  unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
      curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
      *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
      curOp.getBlockScaleFormat(), curOp.getKind());

  return {intId, args};
LogicalResult MmaSpBlockScaleOp::verify() {
  if (!getOrderedMetadata()) {
    return emitOpError("'orderedMetadata' attribute is mandatory");

  if (m == 16 && n == 8 && k == 128) {
    if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
        getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
          "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
    if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
      if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
            "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
      if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
            "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
    } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
      if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
            (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
            "attributes for mma.m16n8k128.mxf4nvf4");
  } else if (m == 16 && n == 8 && k == 64) {
    if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
          getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
          getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
      return emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
                         "attributes for mma.m16n8k64");
LogicalResult ShflOp::verify() {
  auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());

  auto verifyTypeError = [&](Twine desc, Type expectedType,
                             Type actualType) -> LogicalResult {
    return emitOpError("expected " + desc + " to be of type ")
           << expectedType << " but got " << actualType << " instead";

  if (returnStructType) {
    if (!getReturnValueAndIsValid())
      return emitOpError("\"return_value_and_is_valid\" attribute must be "
                         "specified when the return type is a struct type");

    if (returnStructType.getBody().size() != 2)
      return emitOpError("expected return type to be a two-element struct");

    auto resultType = returnStruct[0];
    if (resultType != getVal().getType())
      return verifyTypeError("first element in the returned struct",
                             getVal().getType(), resultType);

    auto predicateType = returnStruct[1];
    if (!predicateType.isInteger(1))
      return verifyTypeError("second element in the returned struct",

  if (getReturnValueAndIsValid())
    return emitOpError("expected return type to be a two-element struct");

    return verifyTypeError("return type", getVal().getType(), getType());

                                                  NVVM::MMAFrag frag, int nRow,
  unsigned numberElements = 0;

  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
  } else if (type == NVVM::MMATypes::f64) {
    elementType = builder.getF64Type();
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    if (parallelSize == 16)
    else if (parallelSize == 8)
    else if (parallelSize == 32)
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();

  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);

static std::pair<mlir::Type, unsigned>
  if (frag == NVVM::MMAFrag::a) {
  } else if (frag == NVVM::MMAFrag::b) {
  assert(nRow && nCol);
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected source pointer in memory "

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";

  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
      return emitOpError("expected destination type to be f64");

  Type dstType = LLVM::LLVMStructType::getLiteral(
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected operands to be a source pointer in memory "

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
    return emitOpError() << "invalid attribute combination";

  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
    return emitOpError() << "expected data operands of type " << typeInfo.first;

LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
    return emitOpError() << "invalid attribute combination";

  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "

  Type dstType = LLVM::LLVMStructType::getLiteral(
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
LogicalResult NVVM::LdMatrixOp::verify() {
  if (m == 8 && n == 8) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "

    if (getEltType() != LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be b16 for 8x8 matrix");
  } else if (m == 8 && n == 16) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "

    if (getLayout() != MMALayout::row) {
      return emitOpError("expected layout to be row for 8x16 matrix");
    if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 8x16 matrix");
  } else if (m == 16 && n == 16) {
    if (num != 1 && num != 2) {
      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "

    if (getLayout() != MMALayout::col) {
      return emitOpError("expected layout to be col for 16x16 matrix");
    if (getEltType() != LdStMatrixEltType::B8 &&
        getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 16x16 matrix");
    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");

  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
  if (numElements == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (numElements == 2 || numElements == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
      return emitOpError("expected destination type is a structure of ")
             << numElements << " elements of type i32";
LogicalResult NVVM::StMatrixOp::verify() {
  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  if (m == 8 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be B16 for 8x8 matrix");
  } else if (m == 16 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B8) {
      return emitOpError("expected element type to be B8 for 16x8 matrix");
    if (getLayout() != NVVM::MMALayout::col) {
      return emitOpError("expected layout to be col for 16x8 matrix");
    return emitOpError("expected shape to be 8x8 or 16x8");
  if (typeA == NVVM::WGMMATypes::tf32)
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
  if (typeA == NVVM::WGMMATypes::b1)

                                            NVVM::WGMMATypes typeA,
                                            NVVM::WGMMATypes typeB) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");

                                      72,  80,  88,  96,  104, 112, 120, 128,
                                      136, 144, 152, 160, 168, 176, 184, 192,
                                      200, 208, 216, 224, 232, 240, 248, 256};
                                           80,  96,  112, 128, 144, 160,
                                           176, 192, 208, 224, 240, 256};
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
          << "all elements in struct must be same type but there is " << t;

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type " << typeD;
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";

    return emitOpError() << typeD << " += " << typeA << " * " << typeB
                         << ", it is not supported.";

    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type " << typeA;

    return emitOpError() << "has input type " << typeA << " n is set to "
                         << getShape().getN() << ", it is not supported.";

  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
        << "given layouts layout_a = " << getLayoutA()
        << " and layout_b = " << getLayoutB() << " for input types " << typeA
        << " requires transpose. However, this is only supported for: "
        << MMATypes::f16 << " and " << MMATypes::bf16;

  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize

  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
        << " `satfinite` can be only used with s32 accumulator, however "
           "the current accumulator is "
2647std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
2650 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2652 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
2654 int expectedOutputRegisters = 0;
2655 if (getTypeD() == WGMMATypes::f16)
2656 expectedOutputRegisters =
getShape().getN() / 4;
2658 expectedOutputRegisters =
getShape().getN() / 2;
2661 llvm::raw_string_ostream ss(ptx);
2666 << ((expectedOutputRegisters * 2) + 2)
2668 "wgmma.mma_async.sync.aligned.m"
2669 << m <<
"n" << n <<
"k" << k <<
"." << outputTypeName <<
"." << getTypeA()
2670 <<
"." << getTypeB();
2671 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2672 NVVM::MMAIntOverflow::satfinite)
2676 for (; regCnt < expectedOutputRegisters; ++regCnt) {
2677 ss <<
"$" << regCnt;
2678 if (regCnt != expectedOutputRegisters - 1)
2684 regCnt = (regCnt * 2);
2685 ss <<
" $" << (regCnt) <<
","
2686 <<
" $" << (regCnt + 1) <<
","
2688 if (getTypeD() != WGMMATypes::s32) {
2689 ss <<
", $" << (regCnt + 3) <<
", $" << (regCnt + 4);
2693 ss <<
", $" << (regCnt + 5) <<
", $" << (regCnt + 6);
2700bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2704 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2711 asmValues.push_back({makeConstantI32(rewriter,
static_cast<int>(getScaleD())),
2713 if (getTypeD() != WGMMATypes::s32) {
2714 asmValues.push_back(
2715 {makeConstantI32(rewriter,
2716 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2718 asmValues.push_back(
2719 {makeConstantI32(rewriter,
2720 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2724 asmValues.push_back(
2725 {makeConstantI32(rewriter,
static_cast<int>(getLayoutA())),
2727 asmValues.push_back(
2728 {makeConstantI32(rewriter, 1 -
static_cast<int>(getLayoutB())),
LogicalResult NVVM::FenceSyncRestrictOp::verify() {
  if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
      getOrder() != NVVM::MemOrderKind::RELEASE)
    return emitOpError("only acquire and release semantics are supported");
  return success();
}

LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
  if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
      getOrder() != NVVM::MemOrderKind::RELEASE)
    return emitOpError("only acquire and release semantics are supported");

  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("only generic is supported for from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::async)
    return emitOpError("only async is supported for to_proxy attribute");

  return success();
}

LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be in between 24 to 256");
  return success();
}
LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 to 15");

  if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
    return emitOpError("reduction are only available when id is 0");

  if (static_cast<bool>(getReductionOp()) !=
      static_cast<bool>(getReductionPredicate()))
    return emitOpError("reduction predicate and reduction operation must be "
                       "specified together");

  return success();
}

LogicalResult NVVM::Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
}
LogicalResult NVVM::MatchSyncOp::verify() {
  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two element struct with "
                         "first element as i32 and second element as i1");
    }
  } else {
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");
    }
  }
  return success();
}

LogicalResult NVVM::VoteSyncOp::verify() {
  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    }
  } else {
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
    }
  }
  return success();
}
LogicalResult NVVM::PrefetchOp::verify() {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();

  if (getTensormap() && cacheLevel)
    return emitOpError("cannot specify both tensormap and cache level");

  if (getTensormap()) {
    if (addressSpace != MemSpace::Generic &&
        addressSpace != MemSpace::Constant) {
      return emitOpError(
          "prefetch tensormap requires a generic or constant pointer");
    }

    if (evictPriority) {
      return emitOpError(
          "prefetch tensormap does not support eviction priority");
    }

    if (getInParamSpace() && addressSpace != MemSpace::Generic) {
      return emitOpError(
          "in_param_space can only be specified for a generic pointer");
    }

  } else if (cacheLevel) {
    if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
        addressSpace != MemSpace::Local) {
      return emitOpError("prefetch to cache level requires a generic, global, "
                         "or local pointer");
    }

    if (getUniform()) {
      if (*cacheLevel != CacheLevel::L1) {
        return emitOpError("unsupported cache level, the only supported uniform "
                           "cache level is L1");
      }

      if (addressSpace != MemSpace::Generic) {
        return emitOpError(
            "prefetch to uniform cache requires a generic pointer");
      }
    }

    if (evictPriority) {
      if (*cacheLevel != CacheLevel::L2)
        return emitOpError(
            "cache eviction priority supported only for cache level L2");

      if (addressSpace != MemSpace::Global)
        return emitOpError("cache eviction priority requires a global pointer");

      if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
          *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
        return emitOpError(
            "unsupported cache eviction priority, only evict_last and "
            "evict_normal are supported");
    }

    if (getPredicate())
      return emitOpError("predicate supported only on prefetch tensormap");
  } else {
    return emitOpError(
        "requires specification of either cache level or tensormap");
  }

  return success();
}
LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
  switch (getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    if (!getType().isInteger(1))
      return emitOpError("is_canceled query type returns an i1");
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    if (!getType().isInteger(32)) {
      return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
                         "get_first_cta_id_z query types return an i32");
    }
    break;
  }
  return success();
}
LogicalResult NVVM::ReduxOp::verify() {
  Type reduxType = getVal().getType();

  if (!reduxType.isF32()) {
    if (getAbs())
      return emitOpError("abs attribute is supported only for f32 type");
    if (getNan())
      return emitOpError("nan attribute is supported only for f32 type");
  }

  NVVM::ReductionKind kind = getKind();
  switch (kind) {
  case NVVM::ReductionKind::ADD:
  case NVVM::ReductionKind::AND:
  case NVVM::ReductionKind::OR:
  case NVVM::ReductionKind::XOR:
  case NVVM::ReductionKind::MAX:
  case NVVM::ReductionKind::MIN:
  case NVVM::ReductionKind::UMAX:
  case NVVM::ReductionKind::UMIN:
    if (!reduxType.isInteger(32))
      return emitOpError() << "'" << kind
                           << "' reduction kind unsupported with " << reduxType
                           << " type. Only supported type is 'i32'.";
    break;
  case NVVM::ReductionKind::FMIN:
  case NVVM::ReductionKind::FMAX:
    if (!reduxType.isF32())
      return emitOpError() << "'" << kind
                           << "' reduction kind unsupported with " << reduxType
                           << " type. Only supported type is 'f32'.";
    break;
  }
  return success();
}
2999LogicalResult NVVM::TensormapReplaceOp::verify() {
3000 auto ord = getOrd();
3001 Value newVal = getNewValue();
3002 auto newValAttr = getNewValueAttr();
3003 auto fieldName = stringifyEnum(getField());
3005 if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
3006 NVVM::TensormapField::GLOBAL_DIM,
3007 NVVM::TensormapField::GLOBAL_STRIDE,
3008 NVVM::TensormapField::ELEMENT_STRIDE},
3010 return emitOpError(
"ordinal is not supported for ")
3011 << fieldName <<
" field";
3013 auto invalidNewVal = [&](llvm::Twine type) -> std::string {
3014 return llvm::Twine(
"new_value must be specified and must be an " + type +
3015 " for " + llvm::Twine(fieldName) +
" field")
3019 auto invalidNewValAttr = [&]() -> std::string {
3020 return (llvm::Twine(
3021 "new_value_attr must be specified and must be a valid ") +
3022 llvm::Twine(fieldName) +
" attribute for " + fieldName +
" field")
3026 switch (getField()) {
3027 case NVVM::TensormapField::GLOBAL_ADDRESS:
3031 case NVVM::TensormapField::RANK:
3035 case NVVM::TensormapField::GLOBAL_STRIDE:
3037 return emitOpError(
"ordinal is required for global_stride field");
3041 case NVVM::TensormapField::BOX_DIM:
3042 case NVVM::TensormapField::GLOBAL_DIM:
3043 case NVVM::TensormapField::ELEMENT_STRIDE:
3046 << stringifyEnum(getField()) <<
" field";
3050 case NVVM::TensormapField::ELEMTYPE:
3051 if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
3054 case NVVM::TensormapField::INTERLEAVE_LAYOUT:
3055 if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
3058 case NVVM::TensormapField::SWIZZLE_MODE:
3059 if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
3062 case NVVM::TensormapField::SWIZZLE_ATOMICITY:
3063 if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
3066 case NVVM::TensormapField::FILL_MODE:
3067 if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
                  unsigned sizeInBits, unsigned start) {
  // Pack `field` into bits [start, start + sizeInBits) of the 64-bit `result`.
  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());

  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
  if (mask != 0xffffffffu)
    field = builder.CreateAnd(field, builder.getInt32(mask));

  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
  field = builder.CreateShl(field, start);

  return builder.CreateOr(result, field);
}

void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                LLVM::ModuleTranslation &mt,
                                                llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
  llvm::Value *smemDesc = builder.getInt64(0);

  // Pack the descriptor fields at their bit offsets: the leading-dim offset
  // (14 bits at bit 16), the stride-dim offset (14 bits at bit 32), and the
  // leading-dim mode (1 bit at bit 52). The remaining descriptor fields are
  // packed the same way.
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);

  mt.mapValue(thisOp.getRes()) = smemDesc;
}
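// Illustrative sketch of the bit packing performed above (not used by the
// lowering): a field is masked to `sizeInBits` bits, shifted to `start`, and
// OR-ed into the 64-bit descriptor.
namespace {
constexpr uint64_t packFieldInto64Bits(uint64_t result, uint64_t field,
                                       unsigned sizeInBits, unsigned start) {
  uint64_t mask =
      sizeInBits < 64 ? ((uint64_t(1) << sizeInBits) - 1) : ~uint64_t(0);
  return result | ((field & mask) << start);
}
static_assert(packFieldInto64Bits(0, 0x3fff, 14, 16) ==
                  (uint64_t(0x3fff) << 16),
              "a 14-bit field lands at bit 16 of the descriptor");
static_assert(packFieldInto64Bits(0, 0x7, 1, 52) == (uint64_t(1) << 52),
              "a 1-bit field is masked down before being placed at bit 52");
} // namespace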
std::string NVVM::MBarrierInitOp::getPtx() {
  return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
                  : std::string("mbarrier.init.b64 [%0], %1;");
}

std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
  return isShared
             ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
             : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
}

std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
  llvm::StringRef space = isShared ? ".shared" : "";

  return llvm::formatv("{\n\t"
                       ".reg .pred P1; \n\t"
                       "LAB_WAIT: \n\t"
                       "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
                       "@P1 bra.uni DONE; \n\t"
                       "bra.uni LAB_WAIT; \n\t"
                       "DONE: \n\t"
                       "}",
                       space);
}
  auto thisOp = cast<NVVM::BarrierOp>(op);
  llvm::Value *barrierId = thisOp.getBarrierId()
                               ? mt.lookupValue(thisOp.getBarrierId())
                               : builder.getInt32(0);
  llvm::Intrinsic::ID id;
  llvm::SmallVector<llvm::Value *> args = {barrierId};
  if (thisOp.getNumberOfThreads()) {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
    args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
  } else if (thisOp.getReductionOp()) {
    switch (*thisOp.getReductionOp()) {
    case NVVM::BarrierReduction::AND:
      id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
      break;
    case NVVM::BarrierReduction::OR:
      id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
      break;
    case NVVM::BarrierReduction::POPC:
      id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
      break;
    }
    args.push_back(builder.CreateICmpNE(
        mt.lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
  } else {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
  }

  return {id, std::move(args)};
}
                                   llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::PMEventOp>(op);
  llvm::Type *i16Ty = builder.getInt16Ty();

  llvm::Value *maskVal;
  if (auto eventAttr = thisOp.getEventIdAttr()) {
    // A single event id is lowered to a one-hot mask.
    uint16_t mask = static_cast<uint16_t>(1u << eventAttr.getInt());
    maskVal = llvm::ConstantInt::get(i16Ty, mask);
  } else {
    maskVal =
        llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
  }

  return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
}
3208 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3210 llvm::Intrinsic::ID
id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3211 : llvm::Intrinsic::nvvm_mbarrier_init;
3216 args.push_back(mt.
lookupValue(thisOp.getCount()));
3218 return {id, std::move(args)};
3223 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3225 llvm::Intrinsic::ID
id = isShared
3226 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3227 : llvm::Intrinsic::nvvm_mbarrier_inval;
3234 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3237 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3240 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3242 static constexpr llvm::Intrinsic::ID IDs[] = {
3243 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3244 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3245 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3246 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3251 args.push_back(mt.
lookupValue(thisOp.getTxcount()));
3253 return {IDs[
index], std::move(args)};
3258 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3261 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3264 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3266 static constexpr llvm::Intrinsic::ID IDs[] = {
3267 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3268 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3269 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3270 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3275 args.push_back(mt.
lookupValue(thisOp.getTxcount()));
3277 return {IDs[
index], std::move(args)};
3282 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
3285 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3288 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3290 static constexpr llvm::Intrinsic::ID IDs[] = {
3291 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
3292 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
3293 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
3294 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
3295 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3296 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
3297 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
3298 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
3300 nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
3301 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3305 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3312 bool hasCount =
static_cast<bool>(thisOp.getCount());
3314 (
id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
3315 return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};
3319 llvm::Value *count =
3321 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3322 return {id, {mbar, count}};
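// Illustrative sketch (not part of the lowering): the mbarrier intrinsic
// tables above are indexed by packing the scope and space flags into a 2-bit
// value, so entry 0 is scope_cta/space_cta and entry 3 is
// scope_cluster/space_cluster.
namespace {
constexpr size_t packMBarrierScopeSpaceIndex(bool isClusterScope,
                                             bool isClusterSpace) {
  return ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
}
static_assert(packMBarrierScopeSpaceIndex(false, false) == 0 &&
                  packMBarrierScopeSpaceIndex(false, true) == 1 &&
                  packMBarrierScopeSpaceIndex(true, false) == 2 &&
                  packMBarrierScopeSpaceIndex(true, true) == 3,
              "index order matches the IDs/relaxedIDs table layout");
} // namespace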
3327 auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
3330 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3333 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3335 static constexpr llvm::Intrinsic::ID IDs[] = {
3336 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
3337 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
3338 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
3339 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
3340 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3341 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
3343 nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
3345 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
3347 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
3348 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3352 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3358 bool hasCount =
static_cast<bool>(thisOp.getCount());
3359 llvm::Value *count =
3361 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3363 return {id, {mbar, count}};
3366bool MBarrierArriveExpectTxOp::getAsmValues(
3373 for (
auto val : getOperands())
3381 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3384 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3387 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3390 static constexpr llvm::Intrinsic::ID IDs[] = {
3391 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3392 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3393 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3394 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
3395 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3396 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3397 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3398 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3399 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
3401 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3404 llvm::Value *txcount = mt.
lookupValue(thisOp.getTxcount());
3405 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3410 return {id, {mbar, txcount}};
3415 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3418 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3421 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3424 static constexpr llvm::Intrinsic::ID IDs[] = {
3425 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3426 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3427 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3428 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3429 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3430 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3431 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3432 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3433 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3435 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3438 llvm::Value *txcount = mt.
lookupValue(thisOp.getTxcount());
3439 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3444 return {id, {mbar, txcount}};
3449 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3451 llvm::Intrinsic::ID
id =
3452 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3453 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3457 args.push_back(mt.
lookupValue(thisOp.getCount()));
3459 return {id, std::move(args)};
3464 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3466 llvm::Intrinsic::ID
id =
3467 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3468 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3472 args.push_back(mt.
lookupValue(thisOp.getCount()));
3474 return {id, std::move(args)};
3479 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3480 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3481 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3484 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3487 static constexpr llvm::Intrinsic::ID IDs[] = {
3488 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3489 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3490 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3491 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3492 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3493 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3494 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3495 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3496 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3498 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3501 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3502 llvm::Value *input = mt.
lookupValue(thisOp.getStateOrPhase());
3507 return {id, {mbar, input}};
3512 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3513 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3514 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3515 bool hasTicks =
static_cast<bool>(thisOp.getTicks());
3519 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3520 (isPhaseParity ? 1 : 0);
3523 static constexpr llvm::Intrinsic::ID IDs[] = {
3524 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3525 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3526 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3527 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3528 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3529 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3530 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3531 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
3532 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3533 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
3534 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
3535 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
3536 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
3537 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
3538 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
3539 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
3540 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
3542 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3545 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3552 args.push_back(mbar);
3553 args.push_back(mt.
lookupValue(thisOp.getStateOrPhase()));
3555 args.push_back(mt.
lookupValue(thisOp.getTicks()));
3557 return {id, std::move(args)};
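// Illustrative sketch (not part of the lowering): MBarrierTryWaitOp extends
// the scope/parity index with a third bit for the optional `ticks` operand,
// selecting the `_tl` intrinsic variants in the upper half of the table.
namespace {
constexpr size_t packTryWaitIndex(bool hasTicks, bool isClusterScope,
                                  bool isPhaseParity) {
  return ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
         (isPhaseParity ? 1 : 0);
}
static_assert(packTryWaitIndex(false, false, false) == 0 &&
                  packTryWaitIndex(true, false, false) == 4 &&
                  packTryWaitIndex(true, true, true) == 7,
              "entries 4..7 are the `_tl` (ticks) variants");
} // namespace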
3562 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
3565 llvm::Intrinsic::ID id;
3566 if (thisOp.getNoinc()) {
3567 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
3568 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
3570 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
3571 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
3577#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
3578 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
3580#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
3581 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
3586 llvm::Intrinsic::ID id;
3588 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
3589 bool hasCpSize =
static_cast<bool>(cpAsyncOp.getCpSize());
3590 switch (cpAsyncOp.getSize()) {
3598 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
3603 llvm_unreachable(
"Invalid copy size in CpAsyncOp.");
3607 args.push_back(mt.
lookupValue(cpAsyncOp.getDst()));
3608 args.push_back(mt.
lookupValue(cpAsyncOp.getSrc()));
3610 args.push_back(mt.
lookupValue(cpAsyncOp.getCpSize()));
3617 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
3619 llvm::Intrinsic::ID
id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
3622 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3626 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3627 llvm::Value *i64Unused =
3628 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3629 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3630 args.push_back(builder.getInt1(hasCacheHint));
3632 return {id, std::move(args)};
3637 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
3641 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
3643 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3647 mlir::Value multicastMask = thisOp.getMulticastMask();
3648 const bool hasMulticastMask =
static_cast<bool>(multicastMask);
3651 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
3652 args.push_back(hasMulticastMask ? mt.
lookupValue(multicastMask)
3658 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3659 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
3660 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3664 args.push_back(builder.getInt1(hasMulticastMask));
3665 args.push_back(builder.getInt1(hasCacheHint));
3667 llvm::Intrinsic::ID
id =
3669 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
3670 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
3672 return {id, std::move(args)};
3677 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
3679 llvm::Intrinsic::ID
id =
3680 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
3683 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
3684 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3688 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3689 llvm::Value *i64Unused =
3690 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3691 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3692 args.push_back(builder.getInt1(hasCacheHint));
3695 if (
mlir::Value byteMask = thisOp.getByteMask()) {
3697 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
3700 return {id, std::move(args)};
3703bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
3710 for (
auto val : getOperands())
3717CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3719 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
3720 const bool isCTAOnly = thisOp.getIsCTAOnly();
3724 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
3726 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
3736 const bool hasMC =
static_cast<bool>(mcMask);
3737 llvm::Value *i16Zero =
3738 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.
getLLVMContext()), 0);
3742 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3743 llvm::Value *i64Zero =
3744 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3750 thisOp.getGroup() ? (
static_cast<int32_t
>(*thisOp.getGroup()) + 1) : 0;
3752 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.
getLLVMContext()), val);
3756 args.push_back(hasMC ? mt.
lookupValue(mcMask) : i16Zero);
3757 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Zero);
3758 args.push_back(builder.getInt1(hasMC));
3759 args.push_back(builder.getInt1(hasCacheHint));
3763 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Zero);
3764 args.push_back(builder.getInt1(hasCacheHint));
3767 constexpr size_t numDims = 5;
3768 constexpr size_t numModes = 5;
3769 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
3770 using TableTy = std::array<rowTy, numModes>;
3771 static constexpr TableTy IDTable{
3772 {{
notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
3773 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
3774 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
3775 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
3776 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
3778 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
3779 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
3780 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
3782 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
3783 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
3784 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
3786 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
3787 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
3788 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
3790 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
3792 static constexpr TableTy IDTableCTA{
3794 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
3795 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
3796 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
3797 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
3798 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
3800 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
3801 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
3802 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
3804 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
3805 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
3806 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
3808 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
3809 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
3810 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
3812 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
3815 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
3816 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
3817 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
3818 size_t mode =
static_cast<size_t>(thisOp.getMode());
3819 size_t dim = thisOp.getCoordinates().size();
3820 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
3822 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
3824 return {id, std::move(args)};
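// Illustrative sketch (not part of the lowering): the intrinsic above is
// picked by a plain [mode][dim] lookup, where `mode` is the TMALoadMode enum
// value and `dim` is the number of coordinates (1-5); unsupported combinations
// hold `notIntrinsic` and are rejected. The tiny table below is a stand-in
// with made-up values, used only to show the indexing scheme.
namespace {
constexpr unsigned kDemoTMATable[2][3] = {
    /*mode 0 (tile-like)  */ {0u /*unused dim 0*/, 101u, 102u},
    /*mode 1 (im2col-like)*/ {0u, 0u, 201u},
};
static_assert(kDemoTMATable[0][2] == 102u && kDemoTMATable[1][1] == 0u,
              "a zero entry plays the role of notIntrinsic for invalid pairs");
} // namespace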
3829 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
3833 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
3835 for (
auto v : thisOp.getCoordinates())
3837 for (
auto v : thisOp.getIm2colOffsets())
3841 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3842 llvm::Value *i64Unused =
3843 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3844 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3845 args.push_back(builder.getInt1(hasCacheHint));
3847 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3848 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3849 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
3850 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
3851 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
3852 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
3853 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
3855 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
3856 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
3857 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
3859 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
3860 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
3861 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
3863 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
3864 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
3865 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
3866 {NI, NI, NI, NI, NI,
3867 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
3869 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
3870 "TMALoadModes must match number of rows in IDTable");
3871 size_t mode =
static_cast<size_t>(thisOp.getMode());
3872 size_t dim = thisOp.getCoordinates().size();
3873 llvm::Intrinsic::ID
id = IDTable[mode][dim];
3874 if (
id == llvm::Intrinsic::not_intrinsic)
3875 llvm_unreachable(
"Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
3877 return {id, std::move(args)};
3881CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
3883 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
3887 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3888 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
3890 for (
auto v : thisOp.getCoordinates())
3894 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3895 llvm::Value *i64Unused =
3896 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
3897 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
3898 args.push_back(builder.getInt1(hasCacheHint));
3900 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3901 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3902 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
3903 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
3904 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
3905 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
3906 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
3907 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
3908 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
3909 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
3910 {NI, NI, NI, NI, NI,
3911 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
3913 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
3914 "TMAStoreModes must match number of rows in IDTable");
3915 size_t mode =
static_cast<size_t>(thisOp.getMode());
3916 size_t dim = thisOp.getCoordinates().size();
3917 llvm::Intrinsic::ID
id = IDTable[mode][dim];
3918 if (
id == llvm::Intrinsic::not_intrinsic)
3920 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
3922 return {id, std::move(args)};
3927 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
3935 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
3936 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
3938 for (
Value v : thisOp.getCoordinates())
3942 const bool hasCacheHint =
static_cast<bool>(cacheHint);
3943 llvm::Value *i64ZeroValue =
3944 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
3945 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64ZeroValue);
3946 args.push_back(builder.getInt1(hasCacheHint));
3948 const llvm::Intrinsic::ID
notIntrinsic = llvm::Intrinsic::not_intrinsic;
3950 constexpr unsigned numRedKinds = 8;
3951 constexpr unsigned numLayouts = 2;
3952 constexpr unsigned maxDim = 5;
3953 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
3954 using layoutTable = std::array<row, numLayouts>;
3955 using fullTable = std::array<layoutTable, numRedKinds>;
3956 static constexpr fullTable IDTable{
3959 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
3960 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
3961 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
3962 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
3963 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
3965 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
3966 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
3967 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
3970 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
3971 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
3972 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
3973 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
3974 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
3976 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
3977 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
3978 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
3981 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
3982 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
3983 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
3984 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
3985 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
3987 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
3988 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
3989 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
3992 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
3993 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
3994 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
3995 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
3996 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
3998 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
3999 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
4000 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
4003 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
4004 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
4005 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
4006 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
4007 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
4009 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
4010 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
4011 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
4014 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
4015 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
4016 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
4017 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
4018 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
4020 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
4021 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
4022 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
4025 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
4026 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
4027 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
4028 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
4029 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
4031 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
4032 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
4033 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
4036 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
4037 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
4038 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
4039 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
4040 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
4042 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
4043 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
4045 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
4047 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
4048 "TMAReduxKinds must match number of rows in IDTable");
4050 size_t redKind =
static_cast<size_t>(thisOp.getRedKind());
4051 size_t mode =
static_cast<size_t>(thisOp.getMode());
4052 size_t dim = thisOp.getCoordinates().size();
4054 assert(redKind < IDTable.size() &&
4055 "Invalid redKind for CpAsyncBulkTensorReduceOp");
4056 assert(mode < IDTable[redKind].size() &&
4057 "Invalid mode for CpAsyncBulkTensorReduceOp");
4058 assert(dim < IDTable[redKind][mode].size() &&
4059 "Invalid dim for CpAsyncBulkTensorReduceOp");
4061 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
4064 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4066 return {intrinsicID, std::move(args)};
4071#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4072 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
4073 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
4075#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
4076 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4077 : CVT_F2TF32_ID_IMPL(rnd, relu, )
4080ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4081 NVVM::SaturationMode sat,
bool hasRelu) {
4082 using RndMode = NVVM::FPRoundingMode;
4083 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4092 llvm_unreachable(
"Invalid RoundingMode for CvtFloatToTF32Op");
4097ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4099 llvm::IRBuilderBase &builder) {
4104 bool hasRelu = op.getRelu();
4106 llvm::Intrinsic::ID intId =
4107 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4108 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4110 return {intId, std::move(args)};
4113#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
4114 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
4115 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
4117llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4120 .Case([&](mlir::Float6E2M3FNType) {
4123 .Case([&](mlir::Float6E3M2FNType) {
4127 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF6x2Op");
4128 return llvm::Intrinsic::not_intrinsic;
4132#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
4133 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
4134 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
4136#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
4137 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
4138 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4141ConvertF32x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4142 NVVM::SaturationMode sat,
bool hasRelu) {
4143 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4144 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4145 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4148 .Case([&](mlir::Float8E4M3FNType) {
4151 .Case([&](mlir::Float8E5M2Type) {
4154 .Case([&](mlir::Float8E8M0FNUType) {
4155 if (hasRoundingModeRZ)
4157 else if (hasRoundingModeRP)
4160 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
4163 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
4164 return llvm::Intrinsic::not_intrinsic;
4168#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
4169 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
4170 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
4172llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy,
4175 .Case([&](mlir::Float8E4M3FNType) {
4178 .Case([&](mlir::Float8E5M2Type) {
4182 llvm_unreachable(
"Invalid conversion in ConvertF16x2ToF8x2Op");
4183 return llvm::Intrinsic::not_intrinsic;
4187#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
4188 has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
4189 : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
4192ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4193 NVVM::SaturationMode sat) {
4194 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4196 case NVVM::FPRoundingMode::RZ:
4198 case NVVM::FPRoundingMode::RP:
4201 llvm_unreachable(
"Invalid rounding mode for CvtBF16x2ToF8x2Op");
4207 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4209 bool hasRelu = curOp.getRelu();
4211 llvm::Intrinsic::ID intId =
4213 .Case([&](Float8E4M3FNType type) {
4214 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4215 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4217 .Case([&](Float8E5M2Type type) {
4218 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4219 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4222 llvm_unreachable(
"Invalid type for ConvertF8x2ToF16x2Op");
4223 return llvm::Intrinsic::not_intrinsic;
4226 llvm::Value *packedI16 =
4227 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4228 llvm::Type::getInt16Ty(builder.getContext()));
4230 return {intId, {packedI16}};
4235 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4237 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4238 llvm::Value *packedI16 =
4239 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4240 llvm::Type::getInt16Ty(builder.getContext()));
4242 return {intId, {packedI16}};
4247 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4249 bool hasRelu = curOp.getRelu();
4251 llvm::Intrinsic::ID intId =
4253 .Case([&](Float6E2M3FNType type) {
4254 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4255 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4257 .Case([&](Float6E3M2FNType type) {
4258 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4259 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4262 llvm_unreachable(
"Invalid type for ConvertF6x2ToF16x2Op");
4263 return llvm::Intrinsic::not_intrinsic;
4266 llvm::Value *packedI16 =
4267 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4268 llvm::Type::getInt16Ty(builder.getContext()));
4270 return {intId, {packedI16}};
4275 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4277 bool hasRelu = curOp.getRelu();
4279 llvm::Intrinsic::ID intId =
4281 .Case([&](Float4E2M1FNType type) {
4282 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4283 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4286 llvm_unreachable(
"Invalid type for ConvertF4x2ToF16x2Op");
4287 return llvm::Intrinsic::not_intrinsic;
4290 llvm::Value *extendedI16 =
4291 builder.CreateZExt(mt.
lookupValue(curOp.getSrc()),
4292 llvm::Type::getInt16Ty(builder.getContext()));
4294 return {intId, {extendedI16}};
4298Tcgen05AllocOp::getIntrinsicIDAndArgs(
Operation &op,
4301 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
4302 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4304 bool isShared = as == NVVMMemorySpace::Shared;
4305 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4307 llvm::Intrinsic::ID id;
4309 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
4310 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
4312 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
4313 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
4323llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
4326 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
4327 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
4328 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
4329 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
4338#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
4339 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
4340 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
4342#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
4343 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
4344 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
4347Tcgen05CommitOp::getIntrinsicIDAndArgs(
Operation &op,
4350 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
4351 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4353 bool isShared = as == NVVMMemorySpace::Shared;
4354 bool hasMulticast =
static_cast<bool>(curOp.getMulticastMask());
4355 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4357 llvm::Intrinsic::ID
id =
4364 args.push_back(mt.
lookupValue(curOp.getMulticastMask()));
4369#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
4370 llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
4372#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
4373 is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
4374 : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
4376#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
4378 if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
4379 return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
4380 if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
4381 return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
4382 return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
4386ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
4388 llvm::IRBuilderBase &builder) {
4389 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4390 llvm::Intrinsic::nvvm_ff2f16x2_rn,
4391 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
4392 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
4393 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
4395 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4396 llvm::Intrinsic::nvvm_ff2f16x2_rz,
4397 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
4398 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
4399 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
4401 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4402 llvm::Intrinsic::nvvm_ff2f16x2_rs,
4403 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
4404 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
4405 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
4408 unsigned hasRelu = op.getRelu() ? 1 : 0;
4409 unsigned hasSatFinite =
4410 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4413 unsigned idx = (hasSatFinite << 1) | hasRelu;
4418 if (op.getRandomBits())
4419 args.push_back(mt.
lookupValue(op.getRandomBits()));
4421 switch (op.getRnd()) {
4422 case FPRoundingMode::RN:
4423 return {rndRNIds[idx], std::move(args)};
4424 case FPRoundingMode::RZ:
4425 return {rndRZIds[idx], std::move(args)};
4426 case FPRoundingMode::RS:
4427 return {rndRSIds[idx], std::move(args)};
4429 llvm_unreachable(
"Invalid rounding mode for ConvertF32x2ToF16x2Op");
4434ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
4436 llvm::IRBuilderBase &builder) {
4437 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4438 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
4439 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
4440 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
4441 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
4443 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4444 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
4445 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
4446 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
4447 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
4449 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4450 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
4451 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
4452 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
4453 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
4456 unsigned hasRelu = op.getRelu() ? 1 : 0;
4457 unsigned hasSatFinite =
4458 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4461 unsigned idx = (hasSatFinite << 1) | hasRelu;
4466 if (op.getRandomBits())
4467 args.push_back(mt.
lookupValue(op.getRandomBits()));
4469 switch (op.getRnd()) {
4470 case FPRoundingMode::RN:
4471 return {rndRNIds[idx], std::move(args)};
4472 case FPRoundingMode::RZ:
4473 return {rndRZIds[idx], std::move(args)};
4474 case FPRoundingMode::RS:
4475 return {rndRSIds[idx], std::move(args)};
4477 llvm_unreachable(
"Invalid rounding mode for ConvertF32x2ToBF16x2Op");
4481llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
4483 bool hasRelu = getRelu();
4486 .Case([&](mlir::Float8E4M3FNType) {
4487 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
4488 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
4490 .Case([&](mlir::Float8E5M2Type) {
4491 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
4492 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
4495 llvm_unreachable(
"Invalid F8 type in ConvertF32x4ToF8x4Op");
4496 return llvm::Intrinsic::not_intrinsic;
4500llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
4502 bool hasRelu = getRelu();
4505 .Case([&](mlir::Float6E2M3FNType) {
4506 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
4507 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
4509 .Case([&](mlir::Float6E3M2FNType) {
4510 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
4511 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
4514 llvm_unreachable(
"Invalid F6 type in ConvertF32x4ToF6x4Op");
4515 return llvm::Intrinsic::not_intrinsic;
4519llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
4521 bool hasRelu = getRelu();
4524 .Case([&](mlir::Float4E2M1FNType) {
4525 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
4526 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
4529 llvm_unreachable(
"Invalid F4 type in ConvertF32x4ToF4x4Op");
4530 return llvm::Intrinsic::not_intrinsic;
4534llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(
Operation &op) {
4535 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
4536 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
4537 auto srcFmt = curOp.getSrcFormat();
4538 auto mc = curOp.getMulticast();
4540 switch (curOp.getShape()) {
4541 case Tcgen05CpShape::SHAPE_128x256b:
4543 case Tcgen05CpShape::SHAPE_128x128b:
4545 case Tcgen05CpShape::SHAPE_4x256b:
4547 case Tcgen05CpShape::SHAPE_32x128b:
4549 case Tcgen05CpShape::SHAPE_64x128b:
4550 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
4554 llvm_unreachable(
"Invalid shape in tcgen05 cp Op");
4561 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
4563 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
    result = emitError("offset argument is only supported for shape 16x32bx2");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}
LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}
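// For NVVM special-register ops that carry a "range" attribute, forward the
// attribute's [lower, upper) bounds as both the unsigned and the signed bounds
// of the inferred result range.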
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
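// Verify that a range attribute satisfies the llvm::ConstantRange constructor
// requirements: Lower == Upper is only legal for the canonical empty/full-set
// encodings, i.e. when both are the minimum or both are the maximum value.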
static LogicalResult
verifyConstantRangeAttr(Operation *op,
                        std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
  if (!rangeAttr)
    return success();

  const llvm::APInt &lower = rangeAttr->getLower();
  const llvm::APInt &upper = rangeAttr->getUpper();
  if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
    unsigned bitWidth = lower.getBitWidth();
    llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
    llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
    return op->emitOpError(
               "invalid range attribute: Lower == Upper, but they aren't min (")
           << llvm::toString(minVal, 10, false) << ") or max ("
           << llvm::toString(maxVal, 10, false)
           << ") value! This is an invalid constant range.";
  }

  return success();
}
static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}
NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
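  // The intrinsic table below is indexed as (isASigned << 1) | isBSigned,
  // giving the u_u / u_s / s_u / s_s dp4a variants in that order.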
  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp4a_u_u,
      llvm::Intrinsic::nvvm_idp4a_u_s,
      llvm::Intrinsic::nvvm_idp4a_s_u,
      llvm::Intrinsic::nvvm_idp4a_s_s,
  };
  return {ids[type], args};
}
NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);

  args.push_back(builder.getInt1(curOp.getBHi()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp2a_u_u,
      llvm::Intrinsic::nvvm_idp2a_u_s,
      llvm::Intrinsic::nvvm_idp2a_s_u,
      llvm::Intrinsic::nvvm_idp2a_s_s,
  };
  return {ids[type], args};
}
static llvm::Value *getParamCastedAddr(llvm::Value *addr,
                                       llvm::IRBuilderBase &builder) {
  return builder.CreateAddrSpaceCast(
      addr,
      llvm::PointerType::get(builder.getContext(),
                             llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
}
NVVM::IDArgPair
PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
                                  LLVM::ModuleTranslation &mt,
                                  llvm::IRBuilderBase &builder) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();
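  // Intrinsic selection: tensormap prefetches map to their dedicated
  // intrinsic; a uniform L1 prefetch uses prefetchu.L1; an L2 prefetch with an
  // eviction priority picks the matching global-L2 variant; everything else
  // dispatches on the pointer's address space and the requested cache level.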
  if (op.getTensormap())
    return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};

  assert(cacheLevel && "expected cache level for non-tensormap prefetch");

  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
    return {llvm::Intrinsic::nvvm_prefetchu_L1, args};

  if (evictPriority && *cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
    case NVVM::CacheEvictionPriority::EvictNormal:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  switch (static_cast<MemSpace>(addressSpace)) {
  case MemSpace::Generic:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
               : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
  case MemSpace::Global:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
  case MemSpace::Local:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}
bool NVVM::InlinePtxOp::getAsmValues(

  for (auto arg : getReadWriteArgs())

  for (auto arg : getResults())

  for (auto arg : getReadOnlyArgs())
NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getSmemAddress()));
  args.push_back(mt.lookupValue(curOp.getMbarrier()));

  llvm::Intrinsic::ID intrinsicID =
      curOp.getMulticast()
          ? llvm::Intrinsic::
                nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
          : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;

  return {intrinsicID, args};
}
NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));

  llvm::Intrinsic::ID intrinsicID;

  switch (curOp.getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    intrinsicID =
        llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
    break;
  }

  return {intrinsicID, args};
}
NVVM::IDArgPair PermuteOp::getIntrinsicIDAndArgs(Operation &op,
                                                 LLVM::ModuleTranslation &mt,
                                                 llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::PermuteOp>(op);
  NVVM::PermuteMode mode = thisOp.getMode();
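  // The table is indexed by the PermuteMode enum value: the generic prmt
  // followed by the f4e, b4e, rc8, ecl, ecr, and rc16 variants.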
  static constexpr llvm::Intrinsic::ID IDs[] = {
      llvm::Intrinsic::nvvm_prmt,     llvm::Intrinsic::nvvm_prmt_f4e,
      llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
      llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
      llvm::Intrinsic::nvvm_prmt_rc16};

  unsigned modeIndex = static_cast<unsigned>(mode);

  args.push_back(mt.lookupValue(thisOp.getSelector()));

  return {IDs[modeIndex], args};
}
NVVM::IDArgPair TensormapReplaceOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::TensormapReplaceOp>(op);

  if (thisOp.getOrd())
    args.push_back(builder.getInt32(thisOp.getOrd().value()));
  if (thisOp.getNewValue())
    args.push_back(mt.lookupValue(thisOp.getNewValue()));
  if (auto attr = thisOp.getNewValueAttr()) {
    unsigned val =
        llvm::TypeSwitch<Attribute, unsigned>(attr)
            .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
                  TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
                  TensormapFillModeAttr>([](auto attr) {
              return static_cast<unsigned>(attr.getValue());
            })
            .Default([](auto attr) -> unsigned {
              llvm_unreachable("Invalid attribute type");
              return 0;
            });
    args.push_back(builder.getInt32(val));
  }
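  // One intrinsic per tensormap field, indexed by the field enum in
  // declaration order.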
  static constexpr llvm::Intrinsic::ID IDs[] = {
      llvm::Intrinsic::nvvm_tensormap_replace_global_address,
      llvm::Intrinsic::nvvm_tensormap_replace_rank,
      llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
      llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
      llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
      llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
      llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
      llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
      llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
      llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
      llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
  };

  unsigned fieldIndex = static_cast<unsigned>(thisOp.getField());

  return {IDs[fieldIndex], args};
}
NVVM::IDArgPair
Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                    llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  const bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
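  // The intrinsic is picked from a five-dimensional table indexed, from
  // outermost to innermost, by: hasDisableOutputLane, hasScaleInputD,
  // isATensor, ctaGroup - 1, and the aShift flag. Combinations with no
  // corresponding intrinsic hold notIntrinsic and are rejected by the assert
  // below.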
  using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
  using CtaGroupArray = std::array<EnableAShiftArray, 2>;
  using IsATensorArray = std::array<CtaGroupArray, 2>;
  using HasScaleInputDArray = std::array<IsATensorArray, 2>;
  using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;

  static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
      {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
      {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
      {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
      {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
      {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
      {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
      llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
  };
  llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
  bool hasScaleInputD = ScaleInputD != nullptr;

  llvm::Value *DisableOutputLane =
      mt.lookupValue(thisOp.getDisableOutputLane());
  bool hasDisableOutputLane = DisableOutputLane != nullptr;

  const unsigned ctaGroup =
      static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));

  llvm::Intrinsic::ID ID =
      tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
                   [ctaGroup - 1][thisOp.getAShift()];

  assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");

  if (hasScaleInputD)
    args.push_back(ScaleInputD);

  if (hasDisableOutputLane)
    args.push_back(DisableOutputLane);

  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));

  if (!hasDisableOutputLane)
    args.push_back(builder.getInt32(ctaGroup));

  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));

  return {ID, args};
}
static LogicalResult
verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
                   NVVM::CTAGroupKind ctaGroup, bool hasAShift,
                   NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
  if (disableOutputLane) {
    mlir::VectorType disableOutputLaneType =
        cast<mlir::VectorType>(disableOutputLane.getType());
    if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
         disableOutputLaneType.getNumElements() != 4) ||
        (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
         disableOutputLaneType.getNumElements() != 8))
      return emitError(loc) << "Disable Output Lane of length "
                            << disableOutputLaneType.getNumElements()
                            << " is incompatible with CtaGroupAttr";
  }

  if (hasAShift && !isATensor)
    return emitError(
        loc, "A-shift can be applied only when matrix A is in tensor memory");

  if (hasAShift && (collectorOp == Tcgen05MMACollectorOp::FILL ||
                    collectorOp == Tcgen05MMACollectorOp::USE))
    return emitError(
        loc, "Cannot use collector buffer operation fill or use with ashift");

  return success();
}
LogicalResult Tcgen05MMAOp::verify() {
  return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
                            getDisableOutputLane(), getCtaGroup(), getAShift(),
                            getCollectorOp(), getLoc());
}
NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
  args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
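  // Same index layout as tcgen05MMAIDs above, populated with the sparse (sp)
  // intrinsic variants.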
  using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
  using CtaGroupArray = std::array<EnableAShiftArray, 2>;
  using IsATensorArray = std::array<CtaGroupArray, 2>;
  using HasScaleInputDArray = std::array<IsATensorArray, 2>;
  using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;

  static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
      {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
      {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
      {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
      {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
  };
  llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
  bool hasScaleInputD = ScaleInputD != nullptr;

  llvm::Value *DisableOutputLane =
      mt.lookupValue(thisOp.getDisableOutputLane());
  bool hasDisableOutputLane = DisableOutputLane != nullptr;

  const unsigned ctaGroup =
      static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));

  llvm::Intrinsic::ID ID =
      tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
                         [ctaGroup - 1][thisOp.getAShift()];

  assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");

  if (hasScaleInputD)
    args.push_back(ScaleInputD);

  if (hasDisableOutputLane)
    args.push_back(DisableOutputLane);

  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));

  if (!hasDisableOutputLane)
    args.push_back(builder.getInt32(ctaGroup));

  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));

  return {ID, args};
}
LogicalResult Tcgen05MMASparseOp::verify() {
  return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
                            getDisableOutputLane(), getCtaGroup(), getAShift(),
                            getCollectorOp(), getLoc());
}
NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
  args.push_back(mt.lookupValue(thisOp.getScaleA()));
  args.push_back(mt.lookupValue(thisOp.getScaleB()));
  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
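  // Pick the intrinsic from the (kind, blockScale, isATensor) combination;
  // unsupported combinations fall through to llvm_unreachable.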
  auto kind = thisOp.getKind();
  auto blockScale = thisOp.getBlockScale();
  llvm::Intrinsic::ID ID = [&]() {
    if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
      }
    } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
        return isATensor
                   ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
                   : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
      }
    } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
      }
    }
    llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
  }();

  return {ID, args};
}
static LogicalResult verifyTcgen05MMABlockScaleOp(
    NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::MMABlockScaleKind kind,
    NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
  if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
      kind == MMABlockScaleKind::MXF4NVF4)
    return emitError(loc, "mxf4nvf4 requires block scale attribute");

  if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
      kind != MMABlockScaleKind::MXF4NVF4)
    return emitError(loc,
                     llvm::formatv("{} kind does not support block16 attribute",
                                   stringifyEnum(kind)));

  return success();
}

LogicalResult Tcgen05MMABlockScaleOp::verify() {
  return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
                                      getBlockScale(), getLoc());
}
NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
  args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
  args.push_back(mt.lookupValue(thisOp.getScaleA()));
  args.push_back(mt.lookupValue(thisOp.getScaleB()));
  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
  auto kind = thisOp.getKind();
  auto blockScale = thisOp.getBlockScale();
  llvm::Intrinsic::ID ID = [&]() {
    if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
      }
    } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
        return isATensor
                   ? llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
                   : llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
      }
    } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
      }
    }
    llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
  }();

  return {ID, args};
}

LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
  return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
                                      getBlockScale(), getLoc());
}
NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));

  mlir::Value ZeroColMask = thisOp.getZeroColMask();
  llvm::Intrinsic::ID ID;
  if (ZeroColMask)
    ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
                   : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
  else
    ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
                   : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;

  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));

  return {ID, args};
}
NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
  args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));

  mlir::Value ZeroColMask = thisOp.getZeroColMask();
  llvm::Intrinsic::ID ID;
  if (ZeroColMask)
    ID = isATensor
             ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
             : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
  else
    ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
                   : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;

  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));

  return {ID, args};
}
#define TCGEN05LDRED(SHAPE, NUM, TYPE)                                         \
  llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE

NVVM::IDArgPair Tcgen05LdRedOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05LdRedOp>(op);

  mlir::VectorType VecResTy =
      cast<mlir::VectorType>(thisOp.getData().getType());
  unsigned Num = VecResTy.getNumElements();
  bool IsFloat = thisOp.getRedVal().getType().isF32();
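  // Each row of the tables below corresponds to a power-of-two vector length
  // (indexed by log2(Num)); the second index, driven by IsFloat, selects
  // between the integer and the f32 flavour of the tcgen05.ld.red intrinsic.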
  llvm::Intrinsic::ID Shape32x32b[][2] = {

  llvm::Intrinsic::ID Shape16x32bx2[][2] = {

  NVVM::Tcgen05LdStShape shape = thisOp.getShape();
  unsigned ID = [&]() {
    unsigned idx = std::log2(Num);
    switch (shape) {
    case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
      return Shape32x32b[idx][IsFloat];
    case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
      return Shape16x32bx2[idx][IsFloat];
    default:
      llvm_unreachable("unhandled tcgen05.ld lowering");
    }
  }();

  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2)
    args.push_back(mt.lookupValue(thisOp.getOffset()));

  args.push_back(
      builder.getInt32(thisOp.getOp() == NVVM::ReductionKind::MIN ? 0 : 1));

  args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getAbs())));
  args.push_back(builder.getInt1(static_cast<unsigned>(thisOp.getNan())));

  return {ID, args};
}
LogicalResult Tcgen05LdRedOp::verify() {
  VectorType data = cast<VectorType>(getData().getType());
  Type redVal = getRedVal().getType();

  if (data.getElementType() != redVal)
    return emitError(
        "type of reduction value and element type of vector data should match");

  if (getOp() != NVVM::ReductionKind::MIN &&
      getOp() != NVVM::ReductionKind::MAX)
    return emitError("only min and max reduction kinds are supported");

  if (redVal.isInteger() && (getAbs() || getNan()))
    return emitError("abs or nan is only applicable for f32 type");

  return success();
}
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }

  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
    }
  }

  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
    }
  }

  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
    if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
        !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be used along with "
             << "'" << NVVMDialect::getReqntidAttrName() << "' and "
             << "'" << NVVMDialect::getClusterDimAttrName() << "'";
    }
  }

  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}
unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
  return static_cast<unsigned>(getValue());
}

bool NVVMMemorySpaceAttr::isValidLoad(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
}

bool NVVMMemorySpaceAttr::isValidStore(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
}

bool NVVMMemorySpaceAttr::isValidAtomicOp(
    ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
    std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAtomicXchg(
    Type type, ptr::AtomicOrdering successOrdering,
    ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(

  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidPtrIntCast(

  assert(false && "unimplemented, see TODO in the source.");
  return false;
}
    int optLevel, StringRef triple, StringRef chip, StringRef features,
    DictionaryAttr flags, ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}
LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();

  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!gpuModuleOp)
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");

  const NVVMCheckSMVersion targetSMVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!targetSMVersion.isMinimumSMVersion())
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");

  if (gpuModuleOp
          ->walk([&](Operation *op) {
            if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
              const NVVMCheckSMVersion requirement =
                  reqOp.getRequiredMinSMVersion();
              if (!requirement.isCompatibleWith(targetSMVersion)) {
                op->emitOpError() << "is not supported on " << getChip();
                return WalkResult::interrupt();
              }
            }
            return WalkResult::advance();
          })
          .wasInterrupted())
    return failure();

  return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"