31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/IR/IRBuilder.h"
34#include "llvm/IR/NVVMIntrinsicUtils.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/FormatVariadic.h"
37#include "llvm/Support/NVPTXAddrSpace.h"
38#include "llvm/Support/raw_ostream.h"
46#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
47#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
// File-local shorthand for llvm::Intrinsic::not_intrinsic, LLVM's sentinel
// intrinsic ID meaning "no intrinsic". NOTE(review): its uses are not visible
// in this chunk; presumably returned where an op has no intrinsic lowering —
// confirm against callers.
49static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
56 auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(
ptr.getType());
57 return ptrTy.getAddressSpace() ==
static_cast<unsigned>(targetAS);
74 NVVMMemorySpace targetAS) {
75 unsigned AS =
static_cast<unsigned>(targetAS);
76 return builder.CreateAddrSpaceCast(
77 ptr, llvm::PointerType::get(builder.getContext(), AS));
81static llvm::nvvm::CTAGroupKind
84 case NVVM::CTAGroupKind::CTA_1:
85 return llvm::nvvm::CTAGroupKind::CG_1;
86 case NVVM::CTAGroupKind::CTA_2:
87 return llvm::nvvm::CTAGroupKind::CG_2;
89 llvm_unreachable(
"unsupported cta_group value");
101 size_t numIm2ColOffsets,
103 if (tensorDims < 1 || tensorDims > 5)
104 return emitError(loc,
"expects coordinates between 1 to 5 dimension");
112 "to use im2col mode, the tensor has to be at least 3-dimensional");
114 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
116 loc,
"im2col offsets must be 2 less than number of coordinates");
121LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
122 TMAStoreMode mode = getMode();
126 if (getPredicate()) {
127 if (mode != TMAStoreMode::TILE)
128 return emitError(
"Inline-ptx lowering supported only for Tile mode.");
129 if (getL2CacheHint())
130 return emitError(
"Inline-ptx lowering unsupported with L2 cache-hint.");
135 case TMAStoreMode::TILE:
137 case TMAStoreMode::IM2COL:
139 case TMAStoreMode::TILE_SCATTER4:
141 return emitError(
"Scatter4 mode expects 5 coordinates");
146LogicalResult CpAsyncOp::verify() {
147 if (getModifier() != LoadCacheModifierKind::CG &&
148 getModifier() != LoadCacheModifierKind::CA)
149 return emitError(
"Only CG and CA cache modifiers are supported.");
150 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
151 return emitError(
"expected byte size to be either 4, 8 or 16.");
152 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
153 return emitError(
"CG cache modifier is only support for 16 bytes copy.");
160 if (tensorDims < 1 || tensorDims > 5)
161 return emitError(loc,
"expects coordinates between 1 to 5 dimension");
163 auto checkTMALoadParams = [&](TMALoadMode mode,
bool isIm2col,
164 size_t expectedIm2colOff) -> LogicalResult {
165 if (isIm2col && (tensorDims < 3))
168 <<
" mode, the tensor has to be at least 3-dimensional";
170 if (numIm2colOff != expectedIm2colOff)
171 return emitError(loc) <<
" im2col offsets expected " << expectedIm2colOff
172 <<
" (provided " << numIm2colOff <<
")";
178 case TMALoadMode::TILE:
179 return checkTMALoadParams(mode,
false, 0);
180 case TMALoadMode::IM2COL:
181 return checkTMALoadParams(mode,
true, tensorDims - 2);
182 case TMALoadMode::IM2COL_W:
183 case TMALoadMode::IM2COL_W_128:
184 return checkTMALoadParams(mode,
true, 2);
185 case TMALoadMode::TILE_GATHER4:
186 return (tensorDims == 5)
187 ? checkTMALoadParams(mode,
false, 0)
188 :
emitError(loc,
"Gather4 mode expects 5 coordinates");
193LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
195 getMode(), getLoc());
198LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
199 TMALoadMode mode = getMode();
200 bool isCTAOnly = getIsCTAOnly();
201 if (getPredicate()) {
203 return emitError(
"Predicate is supported only for shared::cluster mode.");
204 if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
206 "Predicate is supported only for Tile and Im2col modes.");
208 NVVMMemorySpace expectedAS =
209 isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
210 unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().
getType())
212 if (AS != expectedAS)
215 ?
"Shared::cta destination requires address-space 3."
216 :
"Shared::cluster destination requires address-space 7.");
219 if (getMulticastMask())
220 return emitError(
"Multicast is not supported with shared::cta mode.");
222 return emitError(
"CTAGroup is not supported with shared::cta mode.");
227 getMode(), getLoc());
230LogicalResult CpAsyncBulkTensorReduceOp::verify() {
231 TMAStoreMode mode = getMode();
234 case TMAStoreMode::TILE:
236 case TMAStoreMode::IM2COL:
238 case TMAStoreMode::TILE_SCATTER4:
239 return emitError(
"Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
244LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
246 if (isSharedCTA && getMulticastMask())
247 return emitError(
"Multicast is not supported with shared::cta mode.");
253 NVVM::MemScopeKind scope,
254 Value retVal =
nullptr) {
255 if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
256 return op->
emitError(
"mbarrier scope must be either CTA or Cluster");
259 bool hasRetValue =
static_cast<bool>(retVal);
260 if (isSharedCluster && hasRetValue)
262 "mbarrier in shared_cluster space cannot return any value");
267LogicalResult MBarrierArriveOp::verify() {
272LogicalResult MBarrierArriveDropOp::verify() {
277LogicalResult MBarrierArriveExpectTxOp::verify() {
281 if (getPredicate()) {
282 if (getScope() != NVVM::MemScopeKind::CTA)
283 return emitError(
"mbarrier scope must be CTA when using predicate");
286 return emitError(
"mbarrier in shared_cluster space is not supported when "
290 return emitError(
"return-value is not supported when using predicate");
292 if (getRelaxed() ==
true)
293 return emitError(
"mbarrier with relaxed semantics is not supported when "
300LogicalResult MBarrierArriveDropExpectTxOp::verify() {
315 inferredReturnTypes.push_back(IntegerType::get(context, 64));
320MBarrierArriveOp::inferReturnTypes(
MLIRContext *context,
321 std::optional<Location> location,
322 MBarrierArriveOp::Adaptor adaptor,
325 inferredReturnTypes);
328LogicalResult MBarrierArriveDropOp::inferReturnTypes(
329 MLIRContext *context, std::optional<Location> location,
330 MBarrierArriveDropOp::Adaptor adaptor,
333 inferredReturnTypes);
336LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes(
337 MLIRContext *context, std::optional<Location> location,
338 MBarrierArriveExpectTxOp::Adaptor adaptor,
342 if (adaptor.getPredicate())
345 inferredReturnTypes);
348LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes(
349 MLIRContext *context, std::optional<Location> location,
350 MBarrierArriveDropExpectTxOp::Adaptor adaptor,
353 inferredReturnTypes);
363 return inferred == actual;
372bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(
TypeRange l,
376bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(
TypeRange l,
381LogicalResult MBarrierExpectTxOp::verify() {
385LogicalResult MBarrierCompleteTxOp::verify() {
389LogicalResult MBarrierTestWaitOp::verify() {
393LogicalResult MBarrierTryWaitOp::verify() {
397LogicalResult ConvertFloatToTF32Op::verify() {
398 using RndMode = NVVM::FPRoundingMode;
402 return emitError(
"Relu not supported with rna rounding mode.");
409 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
414LogicalResult ConvertF32x2ToF6x2Op::verify() {
417 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
419 << mlir::Float6E2M3FNType::get(ctx) <<
" and "
420 << mlir::Float6E3M2FNType::get(ctx)
421 <<
" types are supported for conversions from f32x2 to f6x2.";
426LogicalResult ConvertF32x2ToF8x2Op::verify() {
427 using RndMode = NVVM::FPRoundingMode;
428 using SatMode = NVVM::SaturationMode;
430 bool isRoundingModeRN = getRnd() == RndMode::RN;
431 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
432 bool isRoundingModeRP = getRnd() == RndMode::RP;
433 bool isSatFinite = getSat() == SatMode::SATFINITE;
435 bool hasRelu = getRelu();
440 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
442 if (!isRoundingModeRN) {
443 return emitOpError(
"Only RN rounding mode is supported for "
444 "conversions from f32x2 to ")
445 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
446 << mlir::Float8E5M2Type::get(ctx) <<
" types";
449 return emitOpError(
"Only SATFINITE saturation mode is supported "
452 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
453 << mlir::Float8E5M2Type::get(ctx) <<
" types";
457 .Case<mlir::Float8E8M0FNUType>([&](
mlir::Type) -> LogicalResult {
458 if (!(isRoundingModeRZ || isRoundingModeRP)) {
459 return emitOpError(
"Only RZ and RP rounding modes are supported for "
460 "conversions from f32x2 to ")
461 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
464 return emitOpError(
"relu not supported for conversions to ")
465 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
471 << mlir::Float8E4M3FNType::get(ctx) <<
", "
472 << mlir::Float8E5M2Type::get(ctx) <<
", and "
473 << mlir::Float8E8M0FNUType::get(ctx)
475 "supported for conversions from f32x2 to f8x2";
479LogicalResult ConvertF16x2ToF8x2Op::verify() {
482 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
484 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
485 << mlir::Float8E5M2Type::get(ctx)
486 <<
" types are supported for conversions from f16x2 to f8x2.";
491LogicalResult ConvertBF16x2ToF8x2Op::verify() {
492 using RndMode = NVVM::FPRoundingMode;
493 using SatMode = NVVM::SaturationMode;
495 bool isRoundingModeRN = getRnd() == RndMode::RN;
496 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
497 bool isRoundingModeRP = getRnd() == RndMode::RP;
498 bool isSatFinite = getSat() == SatMode::SATFINITE;
499 bool hasRelu = getRelu();
504 .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
506 if (!isRoundingModeRN)
507 return emitOpError(
"Only RN rounding mode is supported for "
508 "conversions from bf16x2 to ")
509 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
510 << mlir::Float8E5M2Type::get(ctx) <<
" types";
512 return emitOpError(
"Only SATFINITE saturation mode is supported "
513 "for conversions from bf16x2 to ")
514 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
515 << mlir::Float8E5M2Type::get(ctx) <<
" types";
518 .Case<mlir::Float8E8M0FNUType>([&](
mlir::Type) -> LogicalResult {
519 if (!(isRoundingModeRZ || isRoundingModeRP))
520 return emitOpError(
"Only RZ and RP rounding modes are supported for "
521 "conversions from bf16x2 to ")
522 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
524 return emitOpError(
"relu not supported for conversions to ")
525 << mlir::Float8E8M0FNUType::get(ctx) <<
" type";
529 llvm_unreachable(
"Invalid conversion in ConvertBF16x2ToF8x2Op");
534LogicalResult ConvertF32x2ToF4x2Op::verify() {
537 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
539 << mlir::Float4E2M1FNType::get(ctx)
540 <<
" type is supported for conversions from f32x2 to f4x2.";
545LogicalResult ConvertF8x2ToF16x2Op::verify() {
548 if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
550 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
551 << mlir::Float8E5M2Type::get(ctx)
552 <<
" types are supported for conversions from f8x2 to f16x2.";
557LogicalResult ConvertF8x2ToBF16x2Op::verify() {
559 if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
561 << mlir::Float8E8M0FNUType::get(ctx)
562 <<
" type is supported for conversions from f8x2 to bf16x2.";
567LogicalResult ConvertF6x2ToF16x2Op::verify() {
570 if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
572 << mlir::Float6E2M3FNType::get(ctx) <<
" and "
573 << mlir::Float6E3M2FNType::get(ctx)
574 <<
" types are supported for conversions from f6x2 to f16x2.";
579LogicalResult ConvertF4x2ToF16x2Op::verify() {
582 if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
584 << mlir::Float4E2M1FNType::get(ctx)
585 <<
" type is supported for conversions from f4x2 to f16x2.";
590LogicalResult PermuteOp::verify() {
591 using Mode = NVVM::PermuteMode;
592 bool hasHi =
static_cast<bool>(getHi());
599 return emitError(
"mode '") << getMode() <<
"' requires 'hi' operand.";
607 << getMode() <<
"' does not accept 'hi' operand.";
622 static constexpr FPRoundingMode validRndModes[] = {
623 FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
625 if (!llvm::is_contained(validRndModes, rnd)) {
627 "Only RN, RZ, and RS rounding modes are supported for "
628 "conversions from f32x2 to ")
632 if (rnd == FPRoundingMode::RS) {
633 if (!hasRandomBits) {
634 return op->
emitOpError(
"random_bits is required for RS rounding mode.");
639 "random_bits not supported for RN and RZ rounding modes.");
646LogicalResult ConvertF32x2ToF16x2Op::verify() {
648 getRandomBits() ?
true :
false, *
this);
651LogicalResult ConvertF32x2ToBF16x2Op::verify() {
653 getRandomBits() ?
true :
false, *
this);
656LogicalResult ConvertF32x4ToF8x4Op::verify() {
659 if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
661 << mlir::Float8E4M3FNType::get(ctx) <<
" and "
662 << mlir::Float8E5M2Type::get(ctx)
663 <<
" types are supported for conversions from f32x4 to f8x4.";
668LogicalResult ConvertF32x4ToF6x4Op::verify() {
671 if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
673 << mlir::Float6E2M3FNType::get(ctx) <<
" and "
674 << mlir::Float6E3M2FNType::get(ctx)
675 <<
" types are supported for conversions from f32x4 to f6x4.";
680LogicalResult ConvertF32x4ToF4x4Op::verify() {
683 if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
684 return emitOpError(
"Only ") << mlir::Float4E2M1FNType::get(ctx)
685 <<
" type is supported for conversions from "
691LogicalResult BulkStoreOp::verify() {
692 if (getInitVal() != 0)
693 return emitOpError(
"only 0 is supported for initVal, got ") << getInitVal();
697LogicalResult PMEventOp::verify() {
698 auto eventId = getEventId();
699 auto maskedEventId = getMaskedEventId();
700 if (!maskedEventId && !eventId) {
701 return emitOpError() <<
"either `id` or `mask` must be set";
704 if (maskedEventId && eventId) {
705 return emitOpError() <<
"`id` and `mask` cannot be set at the same time";
709 if (eventId < 0 || eventId > 15) {
710 return emitOpError() <<
"`id` must be between 0 and 15";
714 return llvm::success();
720std::optional<mlir::NVVM::MMATypes>
721MmaOp::inferOperandMMAType(
Type operandElType,
bool isAccumulator) {
723 VectorType::get(2, Float16Type::get(operandElType.
getContext()));
724 if (operandElType.
isF64())
725 return NVVM::MMATypes::f64;
726 if (operandElType.
isF16() || operandElType == half2Type)
727 return NVVM::MMATypes::f16;
728 if (operandElType.
isF32() && isAccumulator)
729 return NVVM::MMATypes::f32;
730 if (operandElType.
isF32() && !isAccumulator)
731 return NVVM::MMATypes::tf32;
732 if (llvm::isa<IntegerType>(operandElType)) {
734 return NVVM::MMATypes::s32;
738 if (
auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
739 if (structType.getBody().empty())
741 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
748 return (type == MMATypes::u4 || type == MMATypes::s4);
752 return (type == MMATypes::u8 || type == MMATypes::s8);
757 type == MMATypes::s32;
760MMATypes MmaOp::accumPtxType() {
761 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
762 getODSOperands(2).getTypes().front(),
true);
763 assert(val.has_value() &&
"accumulator PTX type should always be inferrable");
767MMATypes MmaOp::resultPtxType() {
768 std::optional<mlir::NVVM::MMATypes> val =
769 inferOperandMMAType(getResult().
getType(),
true);
770 assert(val.has_value() &&
"result PTX type should always be inferrable");
776 struct MMAOperandFragment {
777 StringRef operandName;
778 StringRef ptxTypeAttr;
779 SmallVector<Value, 4> regs;
780 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
781 : operandName(name), ptxTypeAttr(ptxTypeName) {}
784 std::array<MMAOperandFragment, 3> frags{
785 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
786 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
787 MMAOperandFragment(
"C",
"")};
789 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
791 for (
unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
792 auto &frag = frags[fragIdx];
793 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
794 for (
auto operandIdx = varOperandSpec.first;
795 operandIdx < varOperandSpec.first + varOperandSpec.second;
797 frag.regs.push_back(this->getOperand(operandIdx));
798 if (operandIdx == 0) {
799 regTypes.push_back(this->getOperand(operandIdx).
getType());
802 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
803 regTypes.back(), fragIdx >= 2);
805 ignoreAttrNames.push_back(frag.ptxTypeAttr);
808 auto printMmaOperand = [&](
const MMAOperandFragment &frag) ->
void {
809 p <<
" " << frag.operandName;
815 for (
const auto &frag : frags) {
816 printMmaOperand(frag);
825 frags[1].regs[0].getType(),
826 frags[2].regs[0].getType()},
835 std::optional<MMAIntOverflow> intOverflow,
836 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
837 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
839 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
844 result.addOperands(operandA);
845 result.addOperands(operandB);
846 result.addOperands(operandC);
848 if (multiplicandPtxTypes) {
849 result.addAttribute(
"multiplicandAPtxType",
850 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
851 result.addAttribute(
"multiplicandBPtxType",
852 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
854 if (
auto res = inferOperandMMAType(operandA[0].
getType(),
false))
855 result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
856 if (
auto res = inferOperandMMAType(operandB[0].
getType(),
false))
857 result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
860 if (multiplicandLayouts) {
861 result.addAttribute(
"layoutA",
862 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
863 result.addAttribute(
"layoutB",
864 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
866 result.addAttribute(
"layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
867 result.addAttribute(
"layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
870 if (intOverflow.has_value())
871 result.addAttribute(
"intOverflowBehavior",
872 MMAIntOverflowAttr::get(ctx, *intOverflow));
873 if (b1Op.has_value())
874 result.addAttribute(
"b1Op", MMAB1OpAttr::get(ctx, *b1Op));
876 result.addTypes(resultType);
878 MmaOp::getOperandSegmentSizeAttr(),
880 static_cast<int32_t>(operandB.size()),
881 static_cast<int32_t>(operandC.size())}));
889 struct MMAOperandFragment {
890 std::optional<MMATypes> elemtype;
891 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
892 SmallVector<Type> regTypes;
896 std::array<MMAOperandFragment, 4> frags;
902 MMAOperandFragment &frag) -> LogicalResult {
932 if (operandTypes.size() != 3)
935 "expected one type for each operand segment but got " +
936 Twine(operandTypes.size()) +
" types");
937 for (
const auto &iter : llvm::enumerate(operandTypes)) {
938 auto &frag = frags[iter.index()];
939 frag.regTypes.resize(frag.regs.size(), iter.value());
943 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
950 frags[3].elemtype = inferOperandMMAType(resultType,
true);
952 std::array<StringRef, 2> names{
"multiplicandAPtxType",
953 "multiplicandBPtxType"};
954 for (
unsigned idx = 0; idx < names.size(); idx++) {
955 const auto &frag = frags[idx];
956 std::optional<NamedAttribute> attr = namedAttributes.
getNamed(names[idx]);
957 if (!frag.elemtype.has_value() && !attr.has_value()) {
960 "attribute " + names[idx] +
961 " is not provided explicitly and cannot be inferred");
963 if (!attr.has_value())
965 names[idx], MMATypesAttr::get(parser.
getContext(), *frag.elemtype));
968 result.addTypes(resultType);
969 if (!namedAttributes.
empty())
970 result.addAttributes(namedAttributes);
971 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
973 static_cast<int32_t>(frags[0].regs.size()),
974 static_cast<int32_t>(frags[1].regs.size()),
975 static_cast<int32_t>(frags[2].regs.size()),
980LogicalResult MmaOp::verify() {
982 auto f16Ty = Float16Type::get(context);
983 auto i32Ty = IntegerType::get(context, 32);
984 auto f16x2Ty = VectorType::get(2, f16Ty);
985 auto f32Ty = Float32Type::get(context);
986 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
987 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
990 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
993 auto f16x2x2StructTy =
994 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
996 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
998 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
1000 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1001 getShapeAttr().getK()};
1007 AllowedShapes allowedShapes;
1008 AllowedTypes expectedA;
1009 AllowedTypes expectedB;
1010 AllowedTypes expectedC;
1015 if (mmaShape[0] == 16) {
1017 Type multiplicandFragType;
1018 switch (*getMultiplicandAPtxType()) {
1019 case MMATypes::tf32:
1021 multiplicandFragType = i32Ty;
1022 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1023 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1025 case MMATypes::bf16:
1027 multiplicandFragType = i32Ty;
1028 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1029 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1033 multiplicandFragType = f16x2Ty;
1034 expectedResult.push_back(f16x2x2StructTy);
1035 expectedResult.push_back(f32x4StructTy);
1049 return emitError(
"invalid shape or multiplicand type: ")
1050 << getMultiplicandAPtxType().value();
1054 expectedResult.push_back(s32x4StructTy);
1055 expectedC.emplace_back(4, i32Ty);
1056 multiplicandFragType = i32Ty;
1058 expectedC.emplace_back(2, f16x2Ty);
1059 expectedC.emplace_back(4, f32Ty);
1062 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
1063 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1064 expectedA.emplace_back(unitA, multiplicandFragType);
1065 expectedB.emplace_back(unitB, multiplicandFragType);
1066 allowedShapes.push_back({16, 8, kFactor});
1067 allowedShapes.push_back({16, 8, kFactor * 2});
1069 if (resultPtxType() != accumPtxType())
1074 if (mmaShape[0] == 8) {
1075 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1076 expectedA.emplace_back(2, f16x2Ty);
1077 expectedB.emplace_back(2, f16x2Ty);
1078 expectedResult.push_back(f16x2x4StructTy);
1079 expectedResult.push_back(f32x8StructTy);
1080 expectedC.emplace_back(4, f16x2Ty);
1081 expectedC.emplace_back(8, f32Ty);
1082 allowedShapes.push_back({8, 8, 4});
1084 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1085 Type f64Ty = Float64Type::get(context);
1086 expectedA.emplace_back(1, f64Ty);
1087 expectedB.emplace_back(1, f64Ty);
1088 expectedC.emplace_back(2, f64Ty);
1089 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1091 allowedShapes.push_back({8, 8, 4});
1094 expectedA.push_back({i32Ty});
1095 expectedB.push_back({i32Ty});
1096 expectedC.push_back({i32Ty, i32Ty});
1097 expectedResult.push_back(s32x2StructTy);
1099 allowedShapes.push_back({8, 8, 32});
1101 allowedShapes.push_back({8, 8, 16});
1102 if (getMultiplicandAPtxType().value() == MMATypes::b1)
1103 allowedShapes.push_back({8, 8, 128});
1107 std::string errorMessage;
1108 llvm::raw_string_ostream errorStream(errorMessage);
1111 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1112 !llvm::is_contained(allowedShapes, mmaShape)) {
1113 errorStream <<
"unimplemented variant for MMA shape <";
1114 llvm::interleaveComma(mmaShape, errorStream);
1120 std::array<StringRef, 3> operandNames{
"A",
"B",
"C"};
1121 for (
const auto &iter : llvm::enumerate(
1123 auto spec = this->getODSOperandIndexAndLength(iter.index());
1125 operand_type_begin() + spec.first +
1127 bool match = llvm::is_contained(iter.value(), operandTySeg);
1130 errorStream <<
"Could not match types for the "
1131 << operandNames[iter.index()]
1132 <<
" operands; expected one of ";
1133 for (
const auto &x : iter.value()) {
1134 errorStream << x.size() <<
"x" << x[0] <<
" ";
1136 errorStream <<
"but got ";
1137 llvm::interleaveComma(operandTySeg, errorStream);
1143 if (!llvm::any_of(expectedResult, [&](
Type expectedResultType) {
1144 return expectedResultType == getResult().getType();
1147 <<
"Could not match allowed types for the result; expected one of ";
1148 llvm::interleaveComma(expectedResult, errorStream);
1149 errorStream <<
" but got " << getResult().getType();
1154 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
1155 return emitOpError(
"op requires " + getB1OpAttrName().strref() +
1163 if (!getIntOverflowBehavior())
1165 getIntOverflowBehaviorAttrName().strref() +
1173 (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
1174 getMultiplicandAPtxType() == MMATypes::f16);
1176 if (!isM8N8K4_F16) {
1178 if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
1179 return emitOpError(
"requires layoutA = #nvvm.mma_layout<row> and "
1180 "layoutB = #nvvm.mma_layout<col> for shape <")
1181 << mmaShape[0] <<
", " << mmaShape[1] <<
", " << mmaShape[2]
1182 <<
"> with element types " << *getMultiplicandAPtxType() <<
" and "
1183 << *getMultiplicandBPtxType()
1184 <<
". Only m8n8k4 with f16 supports other layouts.";
1191MMATypes MmaSpOp::accumPtxType() {
1192 std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
1193 getODSOperands(2).getTypes().front(),
true);
1194 assert(val.has_value() &&
"accumulator PTX type should always be inferrable");
1198MMATypes MmaSpOp::resultPtxType() {
1199 std::optional<mlir::NVVM::MMATypes> val =
1200 MmaOp::inferOperandMMAType(getResult().
getType(),
true);
1201 assert(val.has_value() &&
"result PTX type should always be inferrable");
1207 llvm::IRBuilderBase &builder) {
1208 auto thisOp = cast<NVVM::MmaSpOp>(op);
1216 auto intId = MmaSpOp::getIntrinsicID(
1217 thisOp.getShape().getM(), thisOp.getShape().getN(),
1218 thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
1219 thisOp.getOrderedMetadata(), thisOp.getKind(),
1220 *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
1221 thisOp.accumPtxType(), thisOp.resultPtxType());
1223 return {intId, args};
1228 struct MMAOperandFragment {
1229 StringRef operandName;
1230 StringRef ptxTypeAttr;
1231 SmallVector<Value, 4> regs;
1232 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1233 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1236 std::array<MMAOperandFragment, 5> frags{
1237 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
1238 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
1239 MMAOperandFragment(
"C",
""), MMAOperandFragment(
"sparseMetadata",
""),
1240 MMAOperandFragment(
"selector",
"")};
1242 mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
1245 for (
unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
1246 auto &frag = frags[fragIdx];
1247 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
1248 for (
auto operandIdx = varOperandSpec.first;
1249 operandIdx < varOperandSpec.first + varOperandSpec.second;
1251 frag.regs.push_back(this->getOperand(operandIdx));
1252 if (operandIdx == varOperandSpec.first) {
1253 regTypes.push_back(this->getOperand(operandIdx).
getType());
1256 std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
1257 regTypes.back(), fragIdx >= 2);
1259 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1263 frags[3].regs.push_back(getSparseMetadata());
1264 frags[4].regs.push_back(getSparsitySelector());
1266 auto printMmaSpOperand = [&](
const MMAOperandFragment &frag) ->
void {
1267 p <<
" " << frag.operandName;
1273 for (
const auto &frag : frags)
1274 printMmaSpOperand(frag);
1279 for (
int i = 0; i < 3; ++i) {
1284 p <<
") -> " << getResult().getType();
1291 std::optional<MMAIntOverflow> intOverflow,
1292 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1294 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
1299 result.addOperands(operandA);
1300 result.addOperands(operandB);
1301 result.addOperands(operandC);
1302 result.addOperands(sparseMetadata);
1303 result.addOperands(sparsitySelector);
1305 if (multiplicandPtxTypes) {
1306 result.addAttribute(
"multiplicandAPtxType",
1307 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1308 result.addAttribute(
"multiplicandBPtxType",
1309 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1311 if (
auto res = MmaOp::inferOperandMMAType(operandA[0].
getType(),
false))
1312 result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1313 if (
auto res = MmaOp::inferOperandMMAType(operandB[0].
getType(),
false))
1314 result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1317 if (intOverflow.has_value())
1318 result.addAttribute(
"intOverflowBehavior",
1319 MMAIntOverflowAttr::get(ctx, *intOverflow));
1321 result.addTypes(resultType);
1323 MmaSpOp::getOperandSegmentSizeAttr(),
1325 static_cast<int32_t>(operandB.size()),
1326 static_cast<int32_t>(operandC.size()), 1,
1331 struct MMAOperandFragment {
1332 std::optional<MMATypes> elemtype;
1333 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1334 SmallVector<Type> regTypes;
1338 std::array<MMAOperandFragment, 6> frags;
1343 auto parseMmaSpOperand = [&](StringRef operandName,
1344 MMAOperandFragment &frag) -> LogicalResult {
1355 if (parseMmaSpOperand(
"A", frags[0]).
failed())
1357 if (parseMmaSpOperand(
"B", frags[1]).
failed())
1359 if (parseMmaSpOperand(
"C", frags[2]).
failed())
1361 if (parseMmaSpOperand(
"sparseMetadata", frags[3]).
failed())
1363 if (parseMmaSpOperand(
"selector", frags[4]).
failed())
1379 if (operandTypes.size() != 3)
1382 "expected one type for each operand segment but got " +
1383 Twine(operandTypes.size()) +
" types");
1384 for (
const auto &iter : llvm::enumerate(operandTypes)) {
1385 auto &frag = frags[iter.index()];
1386 frag.regTypes.resize(frag.regs.size(), iter.value());
1391 MmaOp::inferOperandMMAType(frag.regTypes[0],
1399 MmaOp::inferOperandMMAType(resultType,
true);
1414 std::array<StringRef, 2> names{
"multiplicandAPtxType",
1415 "multiplicandBPtxType"};
1416 for (
unsigned idx = 0; idx < names.size(); idx++) {
1417 const auto &frag = frags[idx];
1418 std::optional<NamedAttribute> attr = namedAttributes.
getNamed(names[idx]);
1419 if (!frag.elemtype.has_value() && !attr.has_value()) {
1422 "attribute " + names[idx] +
1423 " is not provided explicitly and cannot be inferred");
1425 if (!attr.has_value())
1427 names[idx], MMATypesAttr::get(parser.
getContext(), *frag.elemtype));
1430 result.addTypes(resultType);
1431 if (!namedAttributes.
empty())
1432 result.addAttributes(namedAttributes);
1433 result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
1435 static_cast<int32_t>(frags[0].regs.size()),
1436 static_cast<int32_t>(frags[1].regs.size()),
1437 static_cast<int32_t>(frags[2].regs.size()),
1444LogicalResult MmaSpOp::verify() {
1446 auto f16Ty = Float16Type::get(context);
1447 auto i32Ty = IntegerType::get(context, 32);
1448 auto f16x2Ty = VectorType::get(2, f16Ty);
1449 auto f32Ty = Float32Type::get(context);
1450 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
1451 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
1453 auto s32x4StructTy =
1454 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
1455 auto f32x8StructTy =
1457 auto f16x2x2StructTy =
1458 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
1459 auto f32x4StructTy =
1460 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
1461 auto s32x2StructTy =
1462 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
1464 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
1465 getShapeAttr().getK()};
1471 AllowedShapes allowedShapes;
1472 AllowedTypes expectedA;
1473 AllowedTypes expectedB;
1474 AllowedTypes expectedC;
1479 if (mmaShape[0] == 16) {
1481 Type multiplicandFragType;
1482 switch (*getMultiplicandAPtxType()) {
1483 case MMATypes::tf32:
1485 multiplicandFragType = i32Ty;
1486 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1487 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1489 allowedShapes.push_back({16, 8, 8});
1490 allowedShapes.push_back({16, 8, 16});
1492 case MMATypes::bf16:
1494 multiplicandFragType = i32Ty;
1495 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
1496 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
1498 allowedShapes.push_back({16, 8, 16});
1499 allowedShapes.push_back({16, 8, 32});
1503 multiplicandFragType = f16x2Ty;
1504 expectedResult.push_back(f16x2x2StructTy);
1505 expectedResult.push_back(f32x4StructTy);
1507 allowedShapes.push_back({16, 8, 16});
1508 allowedShapes.push_back({16, 8, 32});
1514 allowedShapes.push_back({16, 8, 64});
1515 allowedShapes.push_back({16, 8, 128});
1521 allowedShapes.push_back({16, 8, 32});
1522 allowedShapes.push_back({16, 8, 64});
1524 case MMATypes::e4m3:
1525 case MMATypes::e5m2:
1526 case MMATypes::e3m2:
1527 case MMATypes::e2m3:
1528 case MMATypes::e2m1:
1530 multiplicandFragType = i32Ty;
1531 expectedResult.push_back(f16x2x2StructTy);
1532 expectedResult.push_back(f32x4StructTy);
1534 allowedShapes.push_back({16, 8, 64});
1537 return emitError(
"invalid shape or multiplicand type: ")
1538 << getMultiplicandAPtxType().value();
1542 expectedResult.push_back(s32x4StructTy);
1543 expectedC.emplace_back(4, i32Ty);
1544 multiplicandFragType = i32Ty;
1545 }
else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
1546 *getMultiplicandAPtxType() <= MMATypes::e2m1) {
1548 expectedC.emplace_back(2, f16x2Ty);
1549 expectedC.emplace_back(4, f32Ty);
1551 expectedC.emplace_back(2, f16x2Ty);
1552 expectedC.emplace_back(4, f32Ty);
1557 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
1558 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
1559 expectedA.emplace_back(unitA, multiplicandFragType);
1560 expectedB.emplace_back(unitB, multiplicandFragType);
1562 if (resultPtxType() != accumPtxType())
1567 if (mmaShape[0] == 8) {
1568 if (*getMultiplicandAPtxType() == MMATypes::f16) {
1569 expectedA.emplace_back(2, f16x2Ty);
1570 expectedB.emplace_back(2, f16x2Ty);
1571 expectedResult.push_back(f16x2x4StructTy);
1572 expectedResult.push_back(f32x8StructTy);
1573 expectedC.emplace_back(4, f16x2Ty);
1574 expectedC.emplace_back(8, f32Ty);
1575 allowedShapes.push_back({8, 8, 4});
1577 if (*getMultiplicandAPtxType() == MMATypes::f64) {
1578 Type f64Ty = Float64Type::get(context);
1579 expectedA.emplace_back(1, f64Ty);
1580 expectedB.emplace_back(1, f64Ty);
1581 expectedC.emplace_back(2, f64Ty);
1582 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
1584 allowedShapes.push_back({8, 8, 4});
1587 expectedA.push_back({i32Ty});
1588 expectedB.push_back({i32Ty});
1589 expectedC.push_back({i32Ty, i32Ty});
1590 expectedResult.push_back(s32x2StructTy);
1592 allowedShapes.push_back({8, 8, 32});
1594 allowedShapes.push_back({8, 8, 16});
1598 std::string errorMessage;
1599 llvm::raw_string_ostream errorStream(errorMessage);
1602 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
1603 !llvm::is_contained(allowedShapes, mmaShape)) {
1604 errorStream <<
"unimplemented variant for MMA shape <";
1605 llvm::interleaveComma(mmaShape, errorStream);
1611 std::array<StringRef, 3> operandNames{
"A",
"B",
"C"};
1612 for (
const auto &iter : llvm::enumerate(
1614 auto spec = this->getODSOperandIndexAndLength(iter.index());
1616 operand_type_begin() + spec.first +
1618 bool match = llvm::is_contained(iter.value(), operandTySeg);
1621 errorStream <<
"Could not match types for the "
1622 << operandNames[iter.index()]
1623 <<
" operands; expected one of ";
1624 for (
const auto &x : iter.value()) {
1625 errorStream << x.size() <<
"x" << x[0] <<
" ";
1627 errorStream <<
"but got ";
1628 llvm::interleaveComma(operandTySeg, errorStream);
1634 if (!llvm::any_of(expectedResult, [&](
Type expectedResultType) {
1635 return expectedResultType == getResult().getType();
1638 <<
"Could not match allowed types for the result; expected one of ";
1639 llvm::interleaveComma(expectedResult, errorStream);
1640 errorStream <<
" but got " << getResult().getType();
1648 if (!getIntOverflowBehavior())
1650 getIntOverflowBehaviorAttrName().strref() +
1655 if (!getSparseMetadata().
getType().isInteger(32)) {
1656 return emitOpError() <<
"sparse metadata must be i32 type";
1660 if (!getSparsitySelector().
getType().isInteger(32)) {
1661 return emitOpError() <<
"sparsity selector must be i32 type";
1673struct MMAOperandFragment {
1674 StringRef operandName;
1675 StringRef ptxTypeAttr;
1676 SmallVector<Value, 4> regs;
1677 explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
1678 : operandName(name), ptxTypeAttr(ptxTypeName) {}
1685 p <<
" " << name <<
"[";
1704template <
typename Op>
1709 for (
unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
1710 auto &frag = frags[fragIdx];
1711 auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
1712 for (
auto operandIdx = varOperandSpec.first;
1713 operandIdx < varOperandSpec.first + varOperandSpec.second;
1715 frag.regs.push_back(op.getOperand(operandIdx));
1716 if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
1717 regTypes.push_back(op.getOperand(operandIdx).getType());
1721 regTypes.push_back(frag.regs[0].getType());
1723 std::optional<MMATypes> inferredType =
1724 MmaOp::inferOperandMMAType(regTypes.back(),
1727 ignoreAttrNames.push_back(frag.ptxTypeAttr);
1738 auto typeParser = [&]() {
1742 operandTypes.push_back(ty);
1748 if (operandTypes.size() != 3)
1750 "expected exactly 3 types");
1759 if (!attrs.
get(
"multiplicandAPtxType")) {
1760 if (
auto inferredType =
1761 MmaOp::inferOperandMMAType(operandTypes[0],
false)) {
1762 attrs.
set(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
1765 if (!attrs.
get(
"multiplicandBPtxType")) {
1766 if (
auto inferredType =
1767 MmaOp::inferOperandMMAType(operandTypes[1],
false)) {
1768 attrs.
set(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
1774template <
typename OpType>
1777 ScaleVecSize scaleVecSize,
1778 BlockScaleFormat blockScaleFormat,
1779 MMABlockScaleKind kind) {
1781 auto &properties =
result.getOrAddProperties<
typename OpType::Properties>();
1782 properties.setShape(
1784 properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
1785 properties.setBlockScaleFormat(
1786 BlockScaleFormatAttr::get(ctx, blockScaleFormat));
1787 properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));
1794 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
1795 if (multiplicandPtxTypes) {
1796 result.addAttribute(
"multiplicandAPtxType",
1797 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
1798 result.addAttribute(
"multiplicandBPtxType",
1799 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
1801 if (
auto res = MmaOp::inferOperandMMAType(operandA[0].
getType(),
false))
1802 result.addAttribute(
"multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
1803 if (
auto res = MmaOp::inferOperandMMAType(operandB[0].
getType(),
false))
1804 result.addAttribute(
"multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
1809template <
typename OpTy>
1811 return *MmaOp::inferOperandMMAType(
1812 cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
1822 std::array<MMAOperandFragment, 3> frags{
1823 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
1824 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
1825 MMAOperandFragment(
"C",
"")};
1827 mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
1832 for (
const auto &frag : frags)
1837 {getScaleAData(), getByteIdA(), getThreadIdA()});
1839 {getScaleBData(), getByteIdB(), getThreadIdB()});
1846 frags[1].regs[0].getType(),
1847 frags[2].regs[0].getType()},
1853ParseResult MmaBlockScaleOp::parse(
OpAsmParser &parser,
1855 struct LocalOperandFragment {
1856 std::optional<MMATypes> elemtype;
1857 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
1861 std::array<LocalOperandFragment, 3> frags;
1890 for (
const auto &[idx, frag] : llvm::enumerate(frags)) {
1891 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
1894 .resolveOperands(frag.regs, operandTypes[idx], parser.
getNameLoc(),
1904 .resolveOperands(scaleAOperands, scaleTypes, parser.
getNameLoc(),
1914 result.addAttributes(namedAttributes);
1918 result.addTypes(resultTypes);
1919 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1921 static_cast<int32_t>(frags[0].regs.size()),
1922 static_cast<int32_t>(frags[1].regs.size()),
1923 static_cast<int32_t>(frags[2].regs.size()),
1934void MmaBlockScaleOp::build(
1939 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
1940 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
1941 MMABlockScaleKind kind) {
1942 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
1945 blockScaleFormat, kind);
1947 result.addOperands(operandA);
1948 result.addOperands(operandB);
1949 result.addOperands(operandC);
1951 {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
1954 multiplicandPtxTypes);
1956 result.addTypes(resultType);
1957 result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
1959 static_cast<int32_t>(operandA.size()),
1960 static_cast<int32_t>(operandB.size()),
1961 static_cast<int32_t>(operandC.size()),
1973 auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
1977 for (
Value operand : curOp.getOperandA())
1979 for (
Value operand : curOp.getOperandB())
1981 for (
Value operand : curOp.getOperandC())
1985 args.push_back(mt.
lookupValue(curOp.getScaleAData()));
1986 args.push_back(mt.
lookupValue(curOp.getByteIdA()));
1987 args.push_back(mt.
lookupValue(curOp.getThreadIdA()));
1988 args.push_back(mt.
lookupValue(curOp.getScaleBData()));
1989 args.push_back(mt.
lookupValue(curOp.getByteIdB()));
1990 args.push_back(mt.
lookupValue(curOp.getThreadIdB()));
1992 unsigned intId = MmaBlockScaleOp::getIntrinsicID(
1993 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
1994 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
1996 curOp.getBlockScaleFormat(), curOp.getKind());
1998 return {intId, args};
2001LogicalResult MmaBlockScaleOp::verify() {
2007 if (m == 16 && n == 8 && k == 64) {
2008 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2009 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2011 "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
2012 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2013 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2015 "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
2016 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2018 "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
2019 }
else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2020 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2021 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2022 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2023 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
2024 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
2026 "attributes for mma.m16n8k64.mxf4nvf4");
2030 }
else if (m == 16 && n == 8 && k == 32) {
2031 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2032 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2033 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2035 emitOpError(
"unsupported Kind, ScaleVecSize and BlockScaleFormat "
2036 "attributes for mma.m16n8k32");
2049 std::array<MMAOperandFragment, 3> frags{
2050 MMAOperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
2051 MMAOperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
2052 MMAOperandFragment(
"C",
"")};
2054 mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};
2059 for (
const auto &frag : frags)
2068 {getScaleAData(), getByteIdA(), getThreadIdA()});
2070 {getScaleBData(), getByteIdB(), getThreadIdB()});
2077 frags[1].regs[0].getType(),
2078 frags[2].regs[0].getType()},
2084ParseResult MmaSpBlockScaleOp::parse(
OpAsmParser &parser,
2086 struct LocalOperandFragment {
2087 std::optional<MMATypes> elemtype;
2088 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
2092 std::array<LocalOperandFragment, 3> frags;
2128 for (
const auto &[idx, frag] : llvm::enumerate(frags)) {
2129 frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
2132 .resolveOperands(frag.regs, operandTypes[idx], parser.
getNameLoc(),
2141 .resolveOperands(metadataOperands, i32Type, parser.
getNameLoc(),
2154 .resolveOperands(scaleAOperands, scaleTypes, parser.
getNameLoc(),
2164 result.addAttributes(namedAttributes);
2169 if (!
result.attributes.get(
"orderedMetadata"))
2172 result.addTypes(resultTypes);
2173 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2175 static_cast<int32_t>(frags[0].regs.size()),
2176 static_cast<int32_t>(frags[1].regs.size()),
2177 static_cast<int32_t>(frags[2].regs.size()),
2190void MmaSpBlockScaleOp::build(
2196 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
2197 ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
2198 MMABlockScaleKind kind) {
2199 assert(
shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
2202 builder,
result,
shape, scaleVecSize, blockScaleFormat, kind);
2205 result.addOperands(operandA);
2206 result.addOperands(operandB);
2207 result.addOperands(operandC);
2208 result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
2209 threadIdA, scaleBData, byteIdB, threadIdB});
2212 multiplicandPtxTypes);
2214 result.addTypes(resultType);
2215 result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
2217 static_cast<int32_t>(operandA.size()),
2218 static_cast<int32_t>(operandB.size()),
2219 static_cast<int32_t>(operandC.size()),
2233 auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
2237 for (
Value operand : curOp.getOperandA())
2239 for (
Value operand : curOp.getOperandB())
2241 for (
Value operand : curOp.getOperandC())
2245 args.push_back(mt.
lookupValue(curOp.getSparseMetadata()));
2246 args.push_back(mt.
lookupValue(curOp.getSparsitySelector()));
2249 args.push_back(mt.
lookupValue(curOp.getScaleAData()));
2250 args.push_back(mt.
lookupValue(curOp.getByteIdA()));
2251 args.push_back(mt.
lookupValue(curOp.getThreadIdA()));
2252 args.push_back(mt.
lookupValue(curOp.getScaleBData()));
2253 args.push_back(mt.
lookupValue(curOp.getByteIdB()));
2254 args.push_back(mt.
lookupValue(curOp.getThreadIdB()));
2256 unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
2257 curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
2258 *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
2260 curOp.getBlockScaleFormat(), curOp.getKind());
2262 return {intId, args};
2265LogicalResult MmaSpBlockScaleOp::verify() {
2267 if (!getOrderedMetadata()) {
2268 return emitOpError(
"'orderedMetadata' attribute is mandatory");
2276 if (m == 16 && n == 8 && k == 128) {
2277 if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
2278 getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
2280 "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
2281 if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
2282 if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
2284 "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
2285 if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
2287 "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
2288 }
else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
2289 if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
2290 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
2291 (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
2292 (getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3 ||
2293 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))))
2295 "attributes for mma.m16n8k128.mxf4nvf4");
2299 }
else if (m == 16 && n == 8 && k == 64) {
2300 if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
2301 getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
2302 getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
2304 emitOpError(
"unsupported Kind, ScaleVecSize and BlockScaleFormat "
2305 "attributes for mma.m16n8k64");
2312LogicalResult ShflOp::verify() {
2313 auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(
getType());
2315 auto verifyTypeError = [&](Twine desc,
Type expectedType,
2316 Type actualType) -> LogicalResult {
2317 return emitOpError(
"expected " + desc +
" to be of type ")
2318 << expectedType <<
" but got " << actualType <<
" instead";
2321 if (returnStructType) {
2322 if (!getReturnValueAndIsValid())
2323 return emitOpError(
"\"return_value_and_is_valid\" attribute must be "
2324 "specified when the return type is a struct type");
2326 if (returnStructType.getBody().size() != 2)
2327 return emitOpError(
"expected return type to be a two-element struct");
2330 auto resultType = returnStruct[0];
2331 if (resultType != getVal().
getType())
2332 return verifyTypeError(
"first element in the returned struct",
2333 getVal().
getType(), resultType);
2335 auto predicateType = returnStruct[1];
2336 if (!predicateType.isInteger(1))
2337 return verifyTypeError(
"second element in the returned struct",
2341 if (getReturnValueAndIsValid())
2342 return emitOpError(
"expected return type to be a two-element struct");
2345 return verifyTypeError(
"return type", getVal().
getType(),
getType());
2351ShflOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
2352 ShflOp::Adaptor adaptor,
2354 Type valType = adaptor.getVal().getType();
2355 if (adaptor.getReturnValueAndIsValid())
2356 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
2357 context, {valType, IntegerType::get(context, 1)}));
2359 inferredReturnTypes.push_back(valType);
2364 NVVM::MMAFrag frag,
int nRow,
2367 unsigned numberElements = 0;
2370 Type f16x2 = VectorType::get(2, builder.getF16Type());
2371 if (type == NVVM::MMATypes::f16) {
2372 elementType = f16x2;
2373 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2377 }
else if (type == NVVM::MMATypes::f32) {
2378 elementType = builder.getF32Type();
2380 }
else if (type == NVVM::MMATypes::f64) {
2381 elementType = builder.getF64Type();
2382 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
2386 }
else if (type == NVVM::MMATypes::tf32) {
2387 elementType = builder.getI32Type();
2389 }
else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
2390 elementType = builder.getI32Type();
2391 int parallelSize = 0;
2392 if (frag == NVVM::MMAFrag::a)
2393 parallelSize = nRow;
2394 if (frag == NVVM::MMAFrag::b)
2395 parallelSize = nCol;
2398 if (parallelSize == 16)
2401 else if (parallelSize == 8)
2403 else if (parallelSize == 32)
2405 }
else if (type == NVVM::MMATypes::s32) {
2406 elementType = builder.getI32Type();
2409 assert(numberElements != 0 && elementType !=
nullptr);
2410 return std::make_pair(elementType, numberElements);
2413static std::pair<mlir::Type, unsigned>
2417 if (frag == NVVM::MMAFrag::a) {
2420 }
else if (frag == NVVM::MMAFrag::b) {
2427 assert(nRow && nCol);
2431LogicalResult NVVM::WMMALoadOp::verify() {
2432 unsigned addressSpace =
2433 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
2434 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2435 addressSpace != NVVMMemorySpace::Shared)
2436 return emitOpError(
"expected source pointer in memory "
2439 if (NVVM::WMMALoadOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
2440 getEltype(), getFrag()) == 0)
2441 return emitOpError() <<
"invalid attribute combination";
2446 if (typeInfo.first == f64Ty && typeInfo.second == 1) {
2448 return emitOpError(
"expected destination type to be f64");
2452 Type dstType = LLVM::LLVMStructType::getLiteral(
2455 return emitOpError(
"expected destination type is a structure of ")
2456 << typeInfo.second <<
" elements of type " << typeInfo.first;
2460LogicalResult NVVM::WMMAStoreOp::verify() {
2461 unsigned addressSpace =
2462 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
2463 if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
2464 addressSpace != NVVMMemorySpace::Shared)
2465 return emitOpError(
"expected operands to be a source pointer in memory "
2468 if (NVVM::WMMAStoreOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
2470 return emitOpError() <<
"invalid attribute combination";
2473 if (getArgs().size() != typeInfo.second)
2474 return emitOpError() <<
"expected " << typeInfo.second <<
" data operands";
2475 if (llvm::any_of(getArgs(), [&typeInfo](
Value operands) {
2476 return operands.
getType() != typeInfo.first;
2478 return emitOpError() <<
"expected data operands of type " << typeInfo.first;
2482LogicalResult NVVM::WMMAMmaOp::verify() {
2483 if (NVVM::WMMAMmaOp::getIntrinsicID(
getM(),
getN(), getK(), getLayoutA(),
2484 getLayoutB(), getEltypeA(),
2486 return emitOpError() <<
"invalid attribute combination";
2494 arguments.append(typeInfoA.second, typeInfoA.first);
2495 arguments.append(typeInfoB.second, typeInfoB.first);
2496 arguments.append(typeInfoC.second, typeInfoC.first);
2497 unsigned numArgs = arguments.size();
2498 if (getArgs().size() != numArgs)
2499 return emitOpError() <<
"expected " << numArgs <<
" arguments";
2500 for (
unsigned i = 0; i < numArgs; i++) {
2501 if (getArgs()[i].
getType() != arguments[i])
2502 return emitOpError() <<
"expected argument " << i <<
" to be of type "
2505 Type dstType = LLVM::LLVMStructType::getLiteral(
2508 return emitOpError(
"expected destination type is a structure of ")
2509 << typeInfoC.second <<
" elements of type " << typeInfoC.first;
2513LogicalResult NVVM::LdMatrixOp::verify() {
2515 if (m == 8 && n == 8) {
2516 if (num != 1 && num != 2 && num != 4) {
2517 return emitOpError(
"expected num attribute to be 1, 2 or 4 for 8x8 "
2520 if (getEltType() != LdStMatrixEltType::B16) {
2521 return emitOpError(
"expected element type to be b16 for 8x8 matrix");
2523 }
else if (m == 8 && n == 16) {
2524 if (num != 1 && num != 2 && num != 4) {
2525 return emitOpError(
"expected num attribute to be 1, 2 or 4 for 8x16 "
2528 if (getLayout() != MMALayout::row) {
2529 return emitOpError(
"expected layout to be row for 8x16 matrix");
2531 if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2532 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2533 return emitOpError(
"expected element type to be b8x16.b4x16_p64 or "
2534 "b8x16.b6x16_p32 for 8x16 matrix");
2536 }
else if (m == 16 && n == 16) {
2537 if (num != 1 && num != 2) {
2538 return emitOpError(
"expected num attribute to be 1 or 2 for 16x16 "
2541 if (getLayout() != MMALayout::col) {
2542 return emitOpError(
"expected layout to be col for 16x16 matrix");
2544 if (getEltType() != LdStMatrixEltType::B8 &&
2545 getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
2546 getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
2547 return emitOpError(
"expected element type to be b8, b8x16.b4x16_p64 or "
2548 "b8x16.b6x16_p32 for 16x16 matrix");
2551 return emitOpError(
"expected shape to be 8x8, 8x16 or 16x16");
2555 uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
2556 if (numElements == 1 &&
getType() != i32)
2557 return emitOpError(
"expected destination type is i32");
2558 if (numElements == 2 || numElements == 4) {
2559 Type dstType = LLVM::LLVMStructType::getLiteral(
2562 return emitOpError(
"expected destination type is a structure of ")
2563 << numElements <<
" elements of type i32";
2569LogicalResult LdMatrixOp::inferReturnTypes(
2570 MLIRContext *context, std::optional<Location> location,
2572 uint32_t num = adaptor.getNum();
2573 uint32_t m = adaptor.getShape().getM();
2574 uint32_t n = adaptor.getShape().getN();
2575 uint32_t numElements = (m == 16 && n == 16) ? num * 2 : num;
2577 Type i32 = IntegerType::get(context, 32);
2578 if (numElements == 1)
2579 inferredReturnTypes.push_back(i32);
2581 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
2586LogicalResult NVVM::StMatrixOp::verify() {
2587 int numMatrix = getSources().size();
2588 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
2589 return emitOpError(
"expected num attribute to be 1, 2 or 4");
2592 if (m == 8 && n == 8) {
2593 if (getEltType() != NVVM::LdStMatrixEltType::B16) {
2594 return emitOpError(
"expected element type to be B16 for 8x8 matrix");
2596 }
else if (m == 16 && n == 8) {
2597 if (getEltType() != NVVM::LdStMatrixEltType::B8) {
2598 return emitOpError(
"expected element type to be B8 for 16x8 matrix");
2600 if (getLayout() != NVVM::MMALayout::col) {
2601 return emitOpError(
"expected layout to be col for 16x8 matrix");
2604 return emitOpError(
"expected shape to be 8x8 or 16x8");
2610LogicalResult NVVM::MovMatrixOp::verify() {
2612 if (m != 8 || n != 8)
2614 if (getLayout() != NVVM::MMALayout::col)
2616 if (getEltType() != NVVM::LdStMatrixEltType::B16)
2617 return emitOpError(
"expected element type to be b16");
2622 if (typeA == NVVM::WGMMATypes::tf32)
2624 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
2626 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
2628 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
2630 if (typeA == NVVM::WGMMATypes::b1)
2636 NVVM::WGMMATypes typeA,
2637 NVVM::WGMMATypes typeB) {
2639 case NVVM::WGMMATypes::f16:
2640 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2641 typeB == NVVM::WGMMATypes::f16)
2644 case NVVM::WGMMATypes::tf32:
2645 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
2648 case NVVM::WGMMATypes::u8:
2649 case NVVM::WGMMATypes::s8:
2650 if (typeD == NVVM::WGMMATypes::s32 &&
2651 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
2654 case NVVM::WGMMATypes::b1:
2655 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
2658 case NVVM::WGMMATypes::bf16:
2659 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2660 typeB == NVVM::WGMMATypes::bf16)
2663 case NVVM::WGMMATypes::e4m3:
2664 case NVVM::WGMMATypes::e5m2:
2665 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
2666 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
2669 case WGMMATypes::f32:
2670 case WGMMATypes::s32:
2671 llvm_unreachable(
"unsupported input types");
2679 72, 80, 88, 96, 104, 112, 120, 128,
2680 136, 144, 152, 160, 168, 176, 184, 192,
2681 200, 208, 216, 224, 232, 240, 248, 256};
2683 80, 96, 112, 128, 144, 160,
2684 176, 192, 208, 224, 240, 256};
2686 case WGMMATypes::f16:
2687 case WGMMATypes::tf32:
2688 case WGMMATypes::bf16:
2689 case WGMMATypes::e4m3:
2690 case WGMMATypes::e5m2:
2691 if (llvm::is_contained(allowedN, sizeN))
2694 case WGMMATypes::u8:
2695 case WGMMATypes::s8:
2696 case WGMMATypes::b1:
2697 if (llvm::is_contained(allowedNshort, sizeN))
2700 case WGMMATypes::f32:
2701 case WGMMATypes::s32:
2702 llvm_unreachable(
"unsupported input types");
2708LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
2709 Value outValue = getResults();
2710 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.
getType());
2712 return emitOpError() <<
"expected results to be struct";
2713 int outputSize = stype.getBody().size();
2714 WGMMATypes typeD = getTypeD();
2715 WGMMATypes typeA = getTypeA();
2716 WGMMATypes typeB = getTypeB();
2718 for (
Type t : stype.getBody()) {
2719 if (t != stype.getBody().front())
2721 <<
"all elements in struct must be same type but there is " << t;
2724 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
2725 typeD != WGMMATypes::s32) {
2726 return emitOpError() <<
"does not support the given output type " << typeD;
2728 if (typeD == WGMMATypes::s32 &&
2729 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
2730 return emitOpError() <<
"has s32 output, scaleA and scaleB cannot be neg";
2734 return emitOpError() << typeD <<
" += " << typeA <<
" * " << typeB
2735 <<
", it is not supported.";
2745 return emitOpError() <<
"shape 'k' must be " << allowedK.value()
2746 <<
" for input type " << typeA;
2750 return emitOpError() <<
"has input type " << typeA <<
" n is set to "
2751 <<
getShape().getN() <<
", it is not supported.";
2758 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
2759 (getLayoutA() == mlir::NVVM::MMALayout::col ||
2760 getLayoutB() == mlir::NVVM::MMALayout::row)) {
2762 <<
"given layouts layout_a = " << getLayoutA()
2763 <<
" and layout_b = " << getLayoutB() <<
" for input types " << typeA
2765 <<
" requires transpose. However, this is only supported for: "
2766 << MMATypes::f16 <<
" and " << MMATypes::bf16;
2770 int expectedOutput = 0;
2771 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
2772 expectedOutput =
getShape().getN() / 2;
2773 if (typeD == WGMMATypes::f16)
2774 expectedOutput =
getShape().getN() / 4;
2775 if (outputSize != expectedOutput) {
2776 return emitOpError() <<
"results " << expectedOutput
2777 <<
", however output struct has " << outputSize
2781 if (typeD != WGMMATypes::s32 &&
2782 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2783 NVVM::MMAIntOverflow::satfinite) {
2785 <<
" `satfinite` can be only used with s32 accumulator, however "
2786 "the current accumulator is "
2793std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
2796 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2798 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
2800 int expectedOutputRegisters = 0;
2801 if (getTypeD() == WGMMATypes::f16)
2802 expectedOutputRegisters =
getShape().getN() / 4;
2804 expectedOutputRegisters =
getShape().getN() / 2;
2807 llvm::raw_string_ostream ss(ptx);
2812 << ((expectedOutputRegisters * 2) + 2)
2814 "wgmma.mma_async.sync.aligned.m"
2815 << m <<
"n" << n <<
"k" << k <<
"." << outputTypeName <<
"." << getTypeA()
2816 <<
"." << getTypeB();
2817 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2818 NVVM::MMAIntOverflow::satfinite)
2822 for (; regCnt < expectedOutputRegisters; ++regCnt) {
2823 ss <<
"$" << regCnt;
2824 if (regCnt != expectedOutputRegisters - 1)
2830 regCnt = (regCnt * 2);
2831 ss <<
" $" << (regCnt) <<
","
2832 <<
" $" << (regCnt + 1) <<
","
2834 if (getTypeD() != WGMMATypes::s32) {
2835 ss <<
", $" << (regCnt + 3) <<
", $" << (regCnt + 4);
2839 ss <<
", $" << (regCnt + 5) <<
", $" << (regCnt + 6);
2846bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2850 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2857 asmValues.push_back({makeConstantI32(rewriter,
static_cast<int>(getScaleD())),
2859 if (getTypeD() != WGMMATypes::s32) {
2860 asmValues.push_back(
2861 {makeConstantI32(rewriter,
2862 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2864 asmValues.push_back(
2865 {makeConstantI32(rewriter,
2866 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2870 asmValues.push_back(
2871 {makeConstantI32(rewriter,
static_cast<int>(getLayoutA())),
2873 asmValues.push_back(
2874 {makeConstantI32(rewriter, 1 -
static_cast<int>(getLayoutB())),
2880LogicalResult NVVM::FenceProxyOp::verify() {
2881 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2882 return emitOpError() <<
"async_shared fence requires space attribute";
2884 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2885 return emitOpError() <<
"only async_shared fence can have space attribute";
2890LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2891 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2892 return emitOpError(
"uni-directional proxies only support generic for "
2893 "from_proxy attribute");
2895 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2896 return emitOpError(
"uni-directional proxies only support tensormap "
2897 "for to_proxy attribute");
2901LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2902 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2903 return emitOpError(
"uni-directional proxies only support generic for "
2904 "from_proxy attribute");
2906 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2907 return emitOpError(
"uni-directional proxies only support tensormap "
2908 "for to_proxy attribute");
2912LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2913 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2914 return emitOpError(
"only generic is support for from_proxy attribute");
2916 if (getToProxy() != NVVM::ProxyKind::async)
2917 return emitOpError(
"only async is supported for to_proxy attribute");
2921LogicalResult NVVM::SetMaxRegisterOp::verify() {
2922 if (getRegCount() % 8)
2923 return emitOpError(
"new register size must be multiple of 8");
2924 if (getRegCount() < 24 || getRegCount() > 256)
2925 return emitOpError(
"new register size must be in between 24 to 256");
2929LogicalResult NVVM::BarrierOp::verify() {
2930 if (getNumberOfThreads() && !getBarrierId())
2932 "barrier id is missing, it should be set between 0 to 15");
2934 if (getBarrierId() && (
getReductionOp() || getReductionPredicate()))
2935 return emitOpError(
"reduction are only available when id is 0");
2939 return emitOpError(
"reduction predicate and reduction operation must be "
2940 "specified together");
2945LogicalResult BarrierOp::inferReturnTypes(
2946 MLIRContext *context, std::optional<Location> location,
2948 if (adaptor.getReductionOp())
2949 inferredReturnTypes.push_back(IntegerType::get(context, 32));
2957LogicalResult NVVM::Tcgen05CpOp::verify() {
2958 auto mc = getMulticast();
2960 using SH = Tcgen05CpShape;
2961 using MC = Tcgen05CpMulticast;
2963 case SH::SHAPE_128x256b:
2964 case SH::SHAPE_128x128b:
2965 case SH::SHAPE_4x256b:
2967 return emitError(
"Invalid multicast type for tcgen05.cp Op");
2969 case SH::SHAPE_64x128b:
2970 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
2971 return emitError(
"Shape 64x128b requires multicast warpx2_01_23 or "
2972 "warpx2_02_13 for tcgen05.cp Op");
2974 case SH::SHAPE_32x128b:
2975 if (mc != MC::WARPX4)
2977 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
2983LogicalResult NVVM::MatchSyncOp::verify() {
2984 if (getKind() == NVVM::MatchSyncKind::all) {
2985 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(
getType());
2986 if (!type || type.getBody().size() != 2 ||
2987 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2988 return emitOpError(
"match.sync 'all' returns a two element struct with "
2989 "first element as i32 and second element as i1");
2992 if (!
getType().isInteger(32)) {
2993 return emitOpError(
"match.sync 'any' returns an i32");
2999LogicalResult MatchSyncOp::inferReturnTypes(
3000 MLIRContext *context, std::optional<Location> location,
3002 if (adaptor.getKind() == NVVM::MatchSyncKind::all)
3003 inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
3005 {IntegerType::get(context, 32), IntegerType::get(context, 1)}));
3007 inferredReturnTypes.push_back(IntegerType::get(context, 32));
3011LogicalResult NVVM::VoteSyncOp::verify() {
3012 if (getKind() == NVVM::VoteSyncKind::ballot) {
3013 if (!
getType().isInteger(32)) {
3014 return emitOpError(
"vote.sync 'ballot' returns an i32");
3017 if (!
getType().isInteger(1)) {
3018 return emitOpError(
"vote.sync 'any', 'all' and 'uni' returns an i1");
3024LogicalResult VoteSyncOp::inferReturnTypes(
3025 MLIRContext *context, std::optional<Location> location,
3027 unsigned width = adaptor.getKind() == NVVM::VoteSyncKind::ballot ? 32 : 1;
3028 inferredReturnTypes.push_back(IntegerType::get(context, width));
3032LogicalResult NVVM::PrefetchOp::verify() {
3033 using MemSpace = NVVM::NVVMMemorySpace;
3034 using CacheLevel = NVVM::PrefetchCacheLevel;
3036 unsigned addressSpace =
3037 llvm::cast<LLVM::LLVMPointerType>(getAddr().
getType()).getAddressSpace();
3038 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
3039 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
3041 if (getTensormap() && cacheLevel)
3042 return emitOpError(
"cannot specify both tensormap and cache level");
3044 if (getTensormap()) {
3045 if (addressSpace != MemSpace::Generic &&
3046 addressSpace != MemSpace::Constant) {
3048 "prefetch tensormap requires a generic or constant pointer");
3051 if (evictPriority) {
3053 "prefetch tensormap does not support eviction priority");
3056 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
3058 "in_param_space can only be specified for a generic pointer");
3061 }
else if (cacheLevel) {
3062 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
3063 addressSpace != MemSpace::Local) {
3064 return emitOpError(
"prefetch to cache level requires a generic, global, "
3065 "or local pointer");
3069 if (*cacheLevel != CacheLevel::L1) {
3071 "unsupported cache level, the only supported uniform "
3072 "cache level is L1");
3075 if (addressSpace != MemSpace::Generic) {
3077 "prefetch to uniform cache requires a generic pointer");
3081 if (evictPriority) {
3082 if (*cacheLevel != CacheLevel::L2)
3084 "cache eviction priority supported only for cache level L2");
3086 if (addressSpace != MemSpace::Global)
3087 return emitOpError(
"cache eviction priority requires a global pointer");
3089 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
3090 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
3092 "unsupported cache eviction priority, only evict_last and "
3093 "evict_normal are supported");
3097 return emitOpError(
"predicate supported only on prefetch tensormap");
3101 "requires specification of either cache level or tensormap");
3107LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
3108 switch (getQueryType()) {
3109 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
3111 return emitOpError(
"is_canceled query type returns an i1");
3113 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
3114 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
3115 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
3116 if (!
getType().isInteger(32)) {
3117 return emitOpError(
"get_first_cta_id_x, get_first_cta_id_y, "
3118 "get_first_cta_id_z query types return an i32");
3125LogicalResult ClusterLaunchControlQueryCancelOp::inferReturnTypes(
3126 MLIRContext *context, std::optional<Location> location,
3127 ClusterLaunchControlQueryCancelOp::Adaptor adaptor,
3130 adaptor.getQueryType() == NVVM::ClusterLaunchControlQueryType::IS_CANCELED
3133 inferredReturnTypes.push_back(IntegerType::get(context, width));
3137LogicalResult NVVM::ReduxOp::verify() {
3140 if (!reduxType.
isF32()) {
3142 return emitOpError(
"abs attribute is supported only for f32 type");
3144 return emitOpError(
"nan attribute is supported only for f32 type");
3147 NVVM::ReductionKind kind = getKind();
3149 case NVVM::ReductionKind::ADD:
3150 case NVVM::ReductionKind::AND:
3151 case NVVM::ReductionKind::OR:
3152 case NVVM::ReductionKind::XOR:
3153 case NVVM::ReductionKind::MAX:
3154 case NVVM::ReductionKind::MIN:
3155 case NVVM::ReductionKind::UMAX:
3156 case NVVM::ReductionKind::UMIN:
3159 << kind <<
"' reduction kind unsupported with " << reduxType
3160 <<
" type. Only supported type is 'i32'.";
3162 case NVVM::ReductionKind::FMIN:
3163 case NVVM::ReductionKind::FMAX:
3164 if (!reduxType.isF32())
3166 << kind <<
"' reduction kind unsupported with " << reduxType
3167 <<
" type. Only supported type is 'f32'.";
3174LogicalResult NVVM::TensormapReplaceOp::verify() {
3175 auto ord = getOrd();
3176 Value newVal = getNewValue();
3177 auto newValAttr = getNewValueAttr();
3178 auto fieldName = stringifyEnum(getField());
3180 if (ord && !llvm::is_contained({NVVM::TensormapField::BOX_DIM,
3181 NVVM::TensormapField::GLOBAL_DIM,
3182 NVVM::TensormapField::GLOBAL_STRIDE,
3183 NVVM::TensormapField::ELEMENT_STRIDE},
3185 return emitOpError(
"ordinal is not supported for ")
3186 << fieldName <<
" field";
3188 auto invalidNewVal = [&](llvm::Twine type) -> std::string {
3189 return llvm::Twine(
"new_value must be specified and must be an " + type +
3190 " for " + llvm::Twine(fieldName) +
" field")
3194 auto invalidNewValAttr = [&]() -> std::string {
3195 return (llvm::Twine(
3196 "new_value_attr must be specified and must be a valid ") +
3197 llvm::Twine(fieldName) +
" attribute for " + fieldName +
" field")
3201 switch (getField()) {
3202 case NVVM::TensormapField::GLOBAL_ADDRESS:
3206 case NVVM::TensormapField::RANK:
3210 case NVVM::TensormapField::GLOBAL_STRIDE:
3212 return emitOpError(
"ordinal is required for global_stride field");
3216 case NVVM::TensormapField::BOX_DIM:
3217 case NVVM::TensormapField::GLOBAL_DIM:
3218 case NVVM::TensormapField::ELEMENT_STRIDE:
3221 << stringifyEnum(getField()) <<
" field";
3225 case NVVM::TensormapField::ELEMTYPE:
3226 if (!(newValAttr && llvm::isa<TensormapElemtypeAttr>(*newValAttr)))
3229 case NVVM::TensormapField::INTERLEAVE_LAYOUT:
3230 if (!(newValAttr && llvm::isa<TensormapInterleaveLayoutAttr>(*newValAttr)))
3233 case NVVM::TensormapField::SWIZZLE_MODE:
3234 if (!(newValAttr && llvm::isa<TensormapSwizzleModeAttr>(*newValAttr)))
3237 case NVVM::TensormapField::SWIZZLE_ATOMICITY:
3238 if (!(newValAttr && llvm::isa<TensormapSwizzleAtomicityAttr>(*newValAttr)))
3241 case NVVM::TensormapField::FILL_MODE:
3242 if (!(newValAttr && llvm::isa<TensormapFillModeAttr>(*newValAttr)))
3250template <
typename OpType>
3252 mlir::NVVM::FPRoundingMode rndMode = op.getRnd();
3253 mlir::NVVM::SaturationMode satMode = op.getSat();
3254 bool isFTZ = op.getFtz();
3257 mlir::Type opBaseType = isa<VectorType>(opType)
3258 ? cast<VectorType>(opType).getElementType()
3261 if (opBaseType.
isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3262 return op.emitOpError(
"FTZ and saturation are not supported for "
3263 "additions/subtractions involving f64 type");
3265 if (opBaseType.
isF16() && !(rndMode == NVVM::FPRoundingMode::RN ||
3266 rndMode == NVVM::FPRoundingMode::NONE))
3267 return op.emitOpError(
"only RN rounding mode is supported for f16 and "
3268 "vector<2xf16> additions/subtractions");
3270 if (opBaseType.
isBF16()) {
3271 if (rndMode != NVVM::FPRoundingMode::RN &&
3272 rndMode != NVVM::FPRoundingMode::NONE)
3273 return op.emitOpError(
"only RN rounding mode is supported for bf16 and "
3274 "vector<2xbf16> additions/subtractions");
3275 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3276 return op.emitOpError(
"FTZ and saturation are not supported for bf16 and "
3277 "vector<2xbf16> additions/subtractions");
3284 if (opBaseType.
isF16() && isFTZ && satMode == NVVM::SaturationMode::NONE)
3285 return op.emitOpError(
"FTZ with no saturation is not supported for f16 and "
3286 "vector<2xf16> additions/subtractions");
3295LogicalResult NVVM::FmaOp::verify() {
3296 auto opType = getRes().getType();
3297 mlir::NVVM::FPRoundingMode rndMode = getRnd();
3298 mlir::NVVM::SaturationMode satMode = getSat();
3299 bool isFTZ = getFtz();
3300 bool isRelu = getRelu();
3301 bool hasOOB = getOob();
3303 auto getBaseFType = [](
Type type) ->
Type {
3304 if (isa<VectorType>(type))
3305 return cast<VectorType>(type).getElementType();
3309 auto opBaseType = getBaseFType(opType);
3311 if (rndMode == NVVM::FPRoundingMode::NONE)
3312 return emitOpError(
"rounding mode must be specified");
3314 if (isRelu && satMode == NVVM::SaturationMode::SAT)
3315 return emitOpError(
"relu and saturation are not supported together");
3317 if (hasOOB && (satMode == NVVM::SaturationMode::SAT || isFTZ))
3318 return emitOpError(
"oob is not supported with saturation or FTZ");
3320 if (!(opBaseType.isF16() || opBaseType.isBF16()) && (isRelu || hasOOB))
3321 return emitOpError(
"relu and oob are only supported for f16 and bf16");
3323 if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
3324 return emitOpError(
"FTZ and saturation are not supported for f64 type");
3326 if (opBaseType.isF16() && rndMode != NVVM::FPRoundingMode::RN)
3328 "only RN rounding mode is supported for f16 and vector<2xf16>");
3330 if (opBaseType.isBF16()) {
3331 if (rndMode != NVVM::FPRoundingMode::RN)
3333 "only RN rounding mode is supported for bf16 and vector<2xbf16>");
3334 if (satMode != NVVM::SaturationMode::NONE || isFTZ)
3336 "FTZ and saturation are not supported for bf16 and vector<2xbf16>");
3342LogicalResult NVVM::SqrtOp::verify() {
3343 if (getRnd() == NVVM::FPRoundingMode::NONE)
3344 return emitOpError(
"rounding mode cannot be None");
3346 if (getRes().
getType().isF64() && getFtz())
3347 return emitOpError(
"FTZ is not supported for f64");
3358 unsigned sizeInBits,
3360 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3362 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3363 if (mask != 0xffffffffu)
3364 field = builder.CreateAnd(field, builder.getInt32(mask));
3366 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3367 field = builder.CreateShl(field, start);
3369 return builder.CreateOr(
result, field);
3372void Tcgen05MmaSmemDescOp::createSmemDescriptor(
Operation &op,
3374 llvm::IRBuilderBase &builder) {
3375 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
3376 llvm::Value *smemDesc = builder.getInt64(0);
3381 builder, smemDesc, mt.
lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
3383 builder, smemDesc, mt.
lookupValue(thisOp.getStrideDimOffset()), 14, 32);
3389 builder, smemDesc, mt.
lookupValue(thisOp.getLeadingDimMode()), 1, 52);
3393 mt.
mapValue(thisOp.getRes()) = smemDesc;
3400std::string NVVM::MBarrierInitOp::getPtx() {
3402 return isShared ? std::string(
"mbarrier.init.shared.b64 [%0], %1;")
3403 : std::string(
"mbarrier.init.b64 [%0], %1;");
3406std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3409 ? std::string(
"mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3410 : std::string(
"mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
3413std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
3415 llvm::StringRef space = isShared ?
".shared" :
"";
3417 return llvm::formatv(
"{\n\t"
3418 ".reg .pred P1; \n\t"
3420 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
3421 "@P1 bra.uni DONE; \n\t"
3422 "bra.uni LAB_WAIT; \n\t"
3439 LLVM::FNegOp::create(rewriter, loc, op.getRhs().getType(), op.getRhs());
3442 op.getRnd(), op.getSat(), op.getFtz());
3458 auto thisOp = cast<NVVM::BarrierOp>(op);
3459 llvm::Value *barrierId = thisOp.getBarrierId()
3461 : builder.getInt32(0);
3462 llvm::Intrinsic::ID id;
3464 if (thisOp.getNumberOfThreads()) {
3465 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
3466 args.push_back(mt.
lookupValue(thisOp.getNumberOfThreads()));
3467 }
else if (thisOp.getReductionOp()) {
3468 switch (*thisOp.getReductionOp()) {
3469 case NVVM::BarrierReduction::AND:
3470 id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
3472 case NVVM::BarrierReduction::OR:
3473 id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
3475 case NVVM::BarrierReduction::POPC:
3476 id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
3479 args.push_back(builder.CreateICmpNE(
3480 mt.
lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
3482 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
3485 return {id, std::move(args)};
3490 llvm::IRBuilderBase &builder) {
3491 auto thisOp = cast<NVVM::CosOp>(op);
3492 llvm::Intrinsic::ID
id = thisOp.getFtz()
3493 ? llvm::Intrinsic::nvvm_cos_approx_ftz_f
3494 : llvm::Intrinsic::nvvm_cos_approx_f;
3500 llvm::IRBuilderBase &builder) {
3501 auto thisOp = cast<NVVM::SinOp>(op);
3502 llvm::Intrinsic::ID
id = thisOp.getFtz()
3503 ? llvm::Intrinsic::nvvm_sin_approx_ftz_f
3504 : llvm::Intrinsic::nvvm_sin_approx_f;
3510 llvm::IRBuilderBase &builder) {
3511 auto thisOp = cast<NVVM::Log2Op>(op);
3512 llvm::Intrinsic::ID
id = thisOp.getFtz()
3513 ? llvm::Intrinsic::nvvm_lg2_approx_ftz_f
3514 : llvm::Intrinsic::nvvm_lg2_approx_f;
3520 llvm::IRBuilderBase &builder) {
3521 auto thisOp = cast<NVVM::Ex2Op>(op);
3522 llvm::Intrinsic::ID
id = thisOp.getFtz()
3523 ? llvm::Intrinsic::nvvm_ex2_approx_ftz
3524 : llvm::Intrinsic::nvvm_ex2_approx;
3530 llvm::IRBuilderBase &builder) {
3531 auto thisOp = cast<NVVM::RsqrtOp>(op);
3532 Type t = thisOp.getRes().getType();
3533 bool isFtz = thisOp.getFtz();
3535 llvm::Intrinsic::ID
id = [&] {
3537 return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_f
3538 : llvm::Intrinsic::nvvm_rsqrt_approx_f;
3541 return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_d
3542 : llvm::Intrinsic::nvvm_rsqrt_approx_d;
3550 llvm::IRBuilderBase &builder) {
3551 auto thisOp = cast<NVVM::SqrtOp>(op);
3552 Type t = thisOp.getRes().getType();
3553 NVVM::FPRoundingMode rndMode = thisOp.getRnd();
3554 bool isFtz = thisOp.getFtz();
3558 unsigned rndIndex =
static_cast<unsigned>(rndMode) - 1;
3560 static constexpr llvm::Intrinsic::ID f32IDs[] = {
3561 llvm::Intrinsic::nvvm_sqrt_rn_f,
3562 llvm::Intrinsic::nvvm_sqrt_rm_f,
3563 llvm::Intrinsic::nvvm_sqrt_rp_f,
3564 llvm::Intrinsic::nvvm_sqrt_rz_f,
3566 static constexpr llvm::Intrinsic::ID f32FTZIDs[] = {
3567 llvm::Intrinsic::nvvm_sqrt_rn_ftz_f,
3568 llvm::Intrinsic::nvvm_sqrt_rm_ftz_f,
3569 llvm::Intrinsic::nvvm_sqrt_rp_ftz_f,
3570 llvm::Intrinsic::nvvm_sqrt_rz_ftz_f,
3572 static constexpr llvm::Intrinsic::ID f64IDs[] = {
3573 llvm::Intrinsic::nvvm_sqrt_rn_d,
3574 llvm::Intrinsic::nvvm_sqrt_rm_d,
3575 llvm::Intrinsic::nvvm_sqrt_rp_d,
3576 llvm::Intrinsic::nvvm_sqrt_rz_d,
3579 llvm::Intrinsic::ID
id =
3580 t.
isF32() ? (isFtz ? f32FTZIDs[rndIndex] : f32IDs[rndIndex])
3588 llvm::IRBuilderBase &builder) {
3589 auto thisOp = cast<NVVM::SqrtApproxOp>(op);
3590 llvm::Intrinsic::ID
id = thisOp.getFtz()
3591 ? llvm::Intrinsic::nvvm_sqrt_approx_ftz_f
3592 : llvm::Intrinsic::nvvm_sqrt_approx_f;
3598 llvm::IRBuilderBase &builder) {
3599 auto thisOp = cast<NVVM::PMEventOp>(op);
3603 llvm::Value *maskVal;
3604 if (
auto eventAttr = thisOp.getEventIdAttr()) {
3605 uint16_t mask =
static_cast<uint16_t
>(1u << eventAttr.getInt());
3606 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3609 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3612 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3617 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3619 llvm::Intrinsic::ID
id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3620 : llvm::Intrinsic::nvvm_mbarrier_init;
3625 args.push_back(mt.
lookupValue(thisOp.getCount()));
3627 return {id, std::move(args)};
3632 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3634 llvm::Intrinsic::ID
id = isShared
3635 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3636 : llvm::Intrinsic::nvvm_mbarrier_inval;
3643 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3646 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3649 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3651 static constexpr llvm::Intrinsic::ID IDs[] = {
3652 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3653 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3654 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3655 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3660 args.push_back(mt.
lookupValue(thisOp.getTxcount()));
3662 return {IDs[
index], std::move(args)};
3667 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3670 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3673 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3675 static constexpr llvm::Intrinsic::ID IDs[] = {
3676 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3677 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3678 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3679 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3684 args.push_back(mt.
lookupValue(thisOp.getTxcount()));
3686 return {IDs[
index], std::move(args)};
3691 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
3694 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3697 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3699 static constexpr llvm::Intrinsic::ID IDs[] = {
3700 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
3701 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
3702 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
3703 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
3704 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3705 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
3706 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
3707 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
3709 nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
3710 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3714 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3721 bool hasCount =
static_cast<bool>(thisOp.getCount());
3723 (
id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
3724 return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};
3728 llvm::Value *count =
3730 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3731 return {id, {mbar, count}};
3736 auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
3739 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3742 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3744 static constexpr llvm::Intrinsic::ID IDs[] = {
3745 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
3746 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
3747 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
3748 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
3749 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3750 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
3752 nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
3754 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
3756 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
3757 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3761 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3767 bool hasCount =
static_cast<bool>(thisOp.getCount());
3768 llvm::Value *count =
3770 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3772 return {id, {mbar, count}};
3775bool MBarrierArriveExpectTxOp::getAsmValues(
3782 for (
auto val : getOperands())
3790 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3793 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3796 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3799 static constexpr llvm::Intrinsic::ID IDs[] = {
3800 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3801 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3802 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3803 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
3804 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3805 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3806 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3807 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3808 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
3810 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3813 llvm::Value *txcount = mt.
lookupValue(thisOp.getTxcount());
3814 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3819 return {id, {mbar, txcount}};
3824 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3827 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3830 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3833 static constexpr llvm::Intrinsic::ID IDs[] = {
3834 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3835 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3836 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3837 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3838 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3839 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3840 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3841 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3842 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3844 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3847 llvm::Value *txcount = mt.
lookupValue(thisOp.getTxcount());
3848 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3853 return {id, {mbar, txcount}};
3858 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3860 llvm::Intrinsic::ID
id =
3861 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3862 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3866 args.push_back(mt.
lookupValue(thisOp.getCount()));
3868 return {id, std::move(args)};
3873 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3875 llvm::Intrinsic::ID
id =
3876 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3877 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3881 args.push_back(mt.
lookupValue(thisOp.getCount()));
3883 return {id, std::move(args)};
3888 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3889 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3890 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3893 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3896 static constexpr llvm::Intrinsic::ID IDs[] = {
3897 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3898 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3899 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3900 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3901 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3902 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3903 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3904 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3905 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3907 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3910 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3911 llvm::Value *input = mt.
lookupValue(thisOp.getStateOrPhase());
3916 return {id, {mbar, input}};
3921 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3922 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3923 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3924 bool hasTicks =
static_cast<bool>(thisOp.getTicks());
3928 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3929 (isPhaseParity ? 1 : 0);
3932 static constexpr llvm::Intrinsic::ID IDs[] = {
3933 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3934 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3935 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3936 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3937 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3938 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3939 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3940 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
3941 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3942 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
3943 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
3944 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
3945 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
3946 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
3947 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
3948 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
3949 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
3951 auto id = thisOp.getRelaxed() ? relaxedIDs[
index] : IDs[
index];
3954 llvm::Value *mbar = mt.
lookupValue(thisOp.getAddr());
3961 args.push_back(mbar);
3962 args.push_back(mt.
lookupValue(thisOp.getStateOrPhase()));
3964 args.push_back(mt.
lookupValue(thisOp.getTicks()));
3966 return {id, std::move(args)};
3971 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
3974 llvm::Intrinsic::ID id;
3975 if (thisOp.getNoinc()) {
3976 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
3977 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
3979 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
3980 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
3988 llvm::IRBuilderBase &builder) {
3989 auto thisOp = cast<NVVM::MovMatrixOp>(op);
3990 return {llvm::Intrinsic::nvvm_movmatrix_sync_aligned_m8n8_trans_b16,
3994#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
3995 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
3997#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
3998 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
4003 llvm::Intrinsic::ID id;
4005 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
4006 bool hasCpSize =
static_cast<bool>(cpAsyncOp.getCpSize());
4007 switch (cpAsyncOp.getSize()) {
4015 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
4020 llvm_unreachable(
"Invalid copy size in CpAsyncOp.");
4024 args.push_back(mt.
lookupValue(cpAsyncOp.getDst()));
4025 args.push_back(mt.
lookupValue(cpAsyncOp.getSrc()));
4027 args.push_back(mt.
lookupValue(cpAsyncOp.getCpSize()));
4034 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
4036 llvm::Intrinsic::ID
id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
4039 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4043 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4044 llvm::Value *i64Unused =
4045 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4046 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4047 args.push_back(builder.getInt1(hasCacheHint));
4049 return {id, std::move(args)};
4054 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
4058 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
4060 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4064 mlir::Value multicastMask = thisOp.getMulticastMask();
4065 const bool hasMulticastMask =
static_cast<bool>(multicastMask);
4068 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
4069 args.push_back(hasMulticastMask ? mt.
lookupValue(multicastMask)
4075 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4076 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
4077 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4081 args.push_back(builder.getInt1(hasMulticastMask));
4082 args.push_back(builder.getInt1(hasCacheHint));
4084 llvm::Intrinsic::ID
id =
4086 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
4087 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
4089 return {id, std::move(args)};
4094 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
4096 llvm::Intrinsic::ID
id =
4097 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
4100 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
4101 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4105 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4106 llvm::Value *i64Unused =
4107 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4108 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4109 args.push_back(builder.getInt1(hasCacheHint));
4112 if (
mlir::Value byteMask = thisOp.getByteMask()) {
4114 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
4117 return {id, std::move(args)};
4120bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
4127 for (
auto val : getOperands())
4134CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
4136 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
4137 const bool isCTAOnly = thisOp.getIsCTAOnly();
4141 args.push_back(mt.
lookupValue(thisOp.getDstMem()));
4143 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
4153 const bool hasMC =
static_cast<bool>(mcMask);
4154 llvm::Value *i16Zero =
4155 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.
getLLVMContext()), 0);
4159 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4160 llvm::Value *i64Zero =
4161 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4167 thisOp.getGroup() ? (
static_cast<int32_t
>(*thisOp.getGroup()) + 1) : 0;
4169 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.
getLLVMContext()), val);
4173 args.push_back(hasMC ? mt.
lookupValue(mcMask) : i16Zero);
4174 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Zero);
4175 args.push_back(builder.getInt1(hasMC));
4176 args.push_back(builder.getInt1(hasCacheHint));
4180 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Zero);
4181 args.push_back(builder.getInt1(hasCacheHint));
4184 constexpr size_t numDims = 5;
4185 constexpr size_t numModes = 5;
4186 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
4187 using TableTy = std::array<rowTy, numModes>;
4188 static constexpr TableTy IDTable{
4189 {{
notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
4190 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
4191 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
4192 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
4193 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
4195 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
4196 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
4197 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
4199 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
4200 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
4201 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
4203 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
4204 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
4205 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
4207 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
4209 static constexpr TableTy IDTableCTA{
4211 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
4212 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
4213 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
4214 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
4215 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
4217 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
4218 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
4219 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
4221 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
4222 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
4223 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
4225 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
4226 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
4227 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
4229 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
4232 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
4233 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
4234 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
4235 size_t mode =
static_cast<size_t>(thisOp.getMode());
4236 size_t dim = thisOp.getCoordinates().size();
4237 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
4239 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
4241 return {id, std::move(args)};
4246 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
4250 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
4252 for (
auto v : thisOp.getCoordinates())
4254 for (
auto v : thisOp.getIm2colOffsets())
4258 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4259 llvm::Value *i64Unused =
4260 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4261 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4262 args.push_back(builder.getInt1(hasCacheHint));
4264 const unsigned NI = llvm::Intrinsic::not_intrinsic;
4265 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4266 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
4267 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
4268 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
4269 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
4270 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
4272 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
4273 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
4274 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
4276 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
4277 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
4278 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
4280 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
4281 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
4282 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
4283 {NI, NI, NI, NI, NI,
4284 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
4286 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
4287 "TMALoadModes must match number of rows in IDTable");
4288 size_t mode =
static_cast<size_t>(thisOp.getMode());
4289 size_t dim = thisOp.getCoordinates().size();
4290 llvm::Intrinsic::ID
id = IDTable[mode][dim];
4291 if (
id == llvm::Intrinsic::not_intrinsic)
4292 llvm_unreachable(
"Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
4294 return {id, std::move(args)};
4298CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
4300 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
4304 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4305 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
4307 for (
auto v : thisOp.getCoordinates())
4311 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4312 llvm::Value *i64Unused =
4313 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.
getLLVMContext()), 0);
4314 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64Unused);
4315 args.push_back(builder.getInt1(hasCacheHint));
4317 const unsigned NI = llvm::Intrinsic::not_intrinsic;
4318 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
4319 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
4320 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
4321 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
4322 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
4323 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
4324 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
4325 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
4326 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
4327 {NI, NI, NI, NI, NI,
4328 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
4330 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
4331 "TMAStoreModes must match number of rows in IDTable");
4332 size_t mode =
static_cast<size_t>(thisOp.getMode());
4333 size_t dim = thisOp.getCoordinates().size();
4334 llvm::Intrinsic::ID
id = IDTable[mode][dim];
4335 if (
id == llvm::Intrinsic::not_intrinsic)
4337 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
4339 return {id, std::move(args)};
4344 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
4352 args.push_back(mt.
lookupValue(thisOp.getSrcMem()));
4353 args.push_back(mt.
lookupValue(thisOp.getTmaDescriptor()));
4355 for (
Value v : thisOp.getCoordinates())
4359 const bool hasCacheHint =
static_cast<bool>(cacheHint);
4360 llvm::Value *i64ZeroValue =
4361 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
4362 args.push_back(hasCacheHint ? mt.
lookupValue(cacheHint) : i64ZeroValue);
4363 args.push_back(builder.getInt1(hasCacheHint));
4365 const llvm::Intrinsic::ID
notIntrinsic = llvm::Intrinsic::not_intrinsic;
4367 constexpr unsigned numRedKinds = 8;
4368 constexpr unsigned numLayouts = 2;
4369 constexpr unsigned maxDim = 5;
4370 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
4371 using layoutTable = std::array<row, numLayouts>;
4372 using fullTable = std::array<layoutTable, numRedKinds>;
4373 static constexpr fullTable IDTable{
4376 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
4377 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
4378 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
4379 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
4380 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
4382 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
4383 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
4384 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
4387 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
4388 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
4389 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
4390 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
4391 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
4393 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
4394 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
4395 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
4398 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
4399 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
4400 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
4401 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
4402 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
4404 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
4405 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
4406 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
4409 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
4410 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
4411 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
4412 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
4413 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
4415 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
4416 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
4417 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
4420 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
4421 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
4422 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
4423 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
4424 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
4426 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
4427 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
4428 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
4431 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
4432 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
4433 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
4434 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
4435 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
4437 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
4438 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
4439 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
4442 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
4443 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
4444 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
4445 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
4446 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
4448 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
4449 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
4450 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
4453 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
4454 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
4455 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
4456 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
4457 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
4459 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
4460 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
4462 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
4464 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
4465 "TMAReduxKinds must match number of rows in IDTable");
4467 size_t redKind =
static_cast<size_t>(thisOp.getRedKind());
4468 size_t mode =
static_cast<size_t>(thisOp.getMode());
4469 size_t dim = thisOp.getCoordinates().size();
4471 assert(redKind < IDTable.size() &&
4472 "Invalid redKind for CpAsyncBulkTensorReduceOp");
4473 assert(mode < IDTable[redKind].size() &&
4474 "Invalid mode for CpAsyncBulkTensorReduceOp");
4475 assert(dim < IDTable[redKind][mode].size() &&
4476 "Invalid dim for CpAsyncBulkTensorReduceOp");
4478 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
4481 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4483 return {intrinsicID, std::move(args)};
4488#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4489 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
4490 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
4492#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
4493 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4494 : CVT_F2TF32_ID_IMPL(rnd, relu, )
4497ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4498 NVVM::SaturationMode sat,
bool hasRelu) {
4499 using RndMode = NVVM::FPRoundingMode;
4500 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4509 llvm_unreachable(
"Invalid RoundingMode for CvtFloatToTF32Op");
4514ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4516 llvm::IRBuilderBase &builder) {
4521 bool hasRelu = op.getRelu();
4523 llvm::Intrinsic::ID intId =
4524 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4525 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4527 return {intId, std::move(args)};
4530#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
4531 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
4532 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
4534llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4537 .Case([&](mlir::Float6E2M3FNType) {
4540 .Case([&](mlir::Float6E3M2FNType) {
4544 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF6x2Op");
4545 return llvm::Intrinsic::not_intrinsic;
4550ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
4552 llvm::IRBuilderBase &builder) {
4554 bool hasRelu = op.getRelu();
4556 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4558 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4559 intId = hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
4560 : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
4565 return {intId, std::move(args)};
4569ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
4571 llvm::IRBuilderBase &builder) {
4573 bool hasRelu = op.getRelu();
4575 llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
4577 if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
4578 intId = hasRelu ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
4579 : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
4584 return {intId, std::move(args)};
4587llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4590 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4591 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite
4592 : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
4594 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4595 return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite
4596 : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
4599 llvm_unreachable(
"Invalid conversion in ConvertF16x2ToF6x2Op");
4600 return llvm::Intrinsic::not_intrinsic;
4604llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(
mlir::Type dstTy,
4607 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4609 ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite
4610 : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
4612 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4614 ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite
4615 : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
4618 llvm_unreachable(
"Invalid conversion in ConvertBF16x2ToF6x2Op");
4619 return llvm::Intrinsic::not_intrinsic;
4623#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
4624 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
4625 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
4627#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
4628 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
4629 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4632ConvertF32x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4633 NVVM::SaturationMode sat,
bool hasRelu) {
4634 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4635 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4636 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4639 .Case([&](mlir::Float8E4M3FNType) {
4642 .Case([&](mlir::Float8E5M2Type) {
4645 .Case([&](mlir::Float8E8M0FNUType) {
4646 if (hasRoundingModeRZ)
4648 else if (hasRoundingModeRP)
4651 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
4654 llvm_unreachable(
"Invalid conversion in ConvertF32x2ToF8x2Op");
4655 return llvm::Intrinsic::not_intrinsic;
4659#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
4660 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
4661 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
4663llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy,
4666 .Case([&](mlir::Float8E4M3FNType) {
4669 .Case([&](mlir::Float8E5M2Type) {
4673 llvm_unreachable(
"Invalid conversion in ConvertF16x2ToF8x2Op");
4674 return llvm::Intrinsic::not_intrinsic;
4679ConvertBF16x2ToF8x2Op::getIntrinsicID(
mlir::Type dstTy,
4680 NVVM::FPRoundingMode rnd,
4681 NVVM::SaturationMode sat,
bool hasRelu) {
4682 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4684 static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
4685 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
4686 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
4687 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
4688 llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
4692 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4694 ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
4695 : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
4697 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4699 ? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
4700 : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
4702 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
4703 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4704 unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
4705 return ue8m0x2IDs[
index];
4708 llvm_unreachable(
"Invalid conversion in ConvertBF16x2ToF8x2Op");
4709 return llvm::Intrinsic::not_intrinsic;
4715 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4717 bool hasRelu = curOp.getRelu();
4719 llvm::Intrinsic::ID intId =
4721 .Case([&](Float8E4M3FNType type) {
4722 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4723 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4725 .Case([&](Float8E5M2Type type) {
4726 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4727 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4730 llvm_unreachable(
"Invalid type for ConvertF8x2ToF16x2Op");
4731 return llvm::Intrinsic::not_intrinsic;
4734 llvm::Value *packedI16 =
4735 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4736 llvm::Type::getInt16Ty(builder.getContext()));
4738 return {intId, {packedI16}};
4743 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4745 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4746 llvm::Value *packedI16 =
4747 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4748 llvm::Type::getInt16Ty(builder.getContext()));
4750 return {intId, {packedI16}};
4755 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4757 bool hasRelu = curOp.getRelu();
4759 llvm::Intrinsic::ID intId =
4761 .Case([&](Float6E2M3FNType type) {
4762 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4763 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4765 .Case([&](Float6E3M2FNType type) {
4766 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4767 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4770 llvm_unreachable(
"Invalid type for ConvertF6x2ToF16x2Op");
4771 return llvm::Intrinsic::not_intrinsic;
4774 llvm::Value *packedI16 =
4775 builder.CreateBitCast(mt.
lookupValue(curOp.getSrc()),
4776 llvm::Type::getInt16Ty(builder.getContext()));
4778 return {intId, {packedI16}};
4783 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4785 bool hasRelu = curOp.getRelu();
4787 llvm::Intrinsic::ID intId =
4789 .Case([&](Float4E2M1FNType type) {
4790 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4791 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4794 llvm_unreachable(
"Invalid type for ConvertF4x2ToF16x2Op");
4795 return llvm::Intrinsic::not_intrinsic;
4798 llvm::Value *extendedI16 =
4799 builder.CreateZExt(mt.
lookupValue(curOp.getSrc()),
4800 llvm::Type::getInt16Ty(builder.getContext()));
4802 return {intId, {extendedI16}};
4807 auto thisOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);
4808 bool hasRelu = thisOp.getRelu();
4809 bool hasScale =
static_cast<bool>(thisOp.getScaleFactor());
4811 llvm::Intrinsic::ID
id =
4813 ? llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
4814 : llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
4820 args.push_back(hasScale ? mt.
lookupValue(thisOp.getScaleFactor())
4821 : builder.getInt16(0x7f7f));
4822 return {id, std::move(args)};
4827 auto thisOp = cast<NVVM::ConvertBF16x2ToS2F6x2Op>(op);
4828 bool hasRelu = thisOp.getRelu();
4829 bool hasScale =
static_cast<bool>(thisOp.getScaleFactor());
4831 llvm::Intrinsic::ID
id =
4834 nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
4835 : llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
4840 args.push_back(hasScale ? mt.
lookupValue(thisOp.getScaleFactor())
4841 : builder.getInt16(0x7f7f));
4842 return {id, std::move(args)};
4847 auto thisOp = cast<NVVM::ConvertS2F6x2ToBF16x2Op>(op);
4848 bool hasRelu = thisOp.getRelu();
4849 bool hasScale =
static_cast<bool>(thisOp.getScaleFactor());
4850 bool hasSat = thisOp.getSat() == NVVM::SaturationMode::SATFINITE;
4852 static constexpr llvm::Intrinsic::ID ids[] = {
4853 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0,
4854 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
4855 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
4856 llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
4859 unsigned idx = (hasSat << 1) | hasRelu;
4863 llvm::Value *packedI16 =
4864 builder.CreateBitCast(mt.
lookupValue(thisOp.getSrc()),
4865 llvm::Type::getInt16Ty(builder.getContext()));
4866 args.push_back(packedI16);
4867 args.push_back(hasScale ? mt.
lookupValue(thisOp.getScaleFactor())
4868 : builder.getInt16(0x7f7f));
4870 return {ids[idx], std::move(args)};
4874Tcgen05AllocOp::getIntrinsicIDAndArgs(
Operation &op,
4877 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
4878 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4880 bool isShared = as == NVVMMemorySpace::Shared;
4881 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4883 llvm::Intrinsic::ID id;
4885 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
4886 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
4888 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
4889 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
4899llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
4902 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
4903 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
4904 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
4905 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
4914#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
4915 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
4916 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
4918#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
4919 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
4920 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
4923Tcgen05CommitOp::getIntrinsicIDAndArgs(
Operation &op,
4926 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
4927 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4929 bool isShared = as == NVVMMemorySpace::Shared;
4930 bool hasMulticast =
static_cast<bool>(curOp.getMulticastMask());
4931 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4933 llvm::Intrinsic::ID
id =
4940 args.push_back(mt.
lookupValue(curOp.getMulticastMask()));
4945#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
4946 llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
4948#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
4949 is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
4950 : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
4952#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
4954 if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
4955 return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
4956 if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
4957 return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
4958 return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
4962ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
4964 llvm::IRBuilderBase &builder) {
4965 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4966 llvm::Intrinsic::nvvm_ff2f16x2_rn,
4967 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
4968 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
4969 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
4971 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4972 llvm::Intrinsic::nvvm_ff2f16x2_rz,
4973 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
4974 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
4975 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
4977 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4978 llvm::Intrinsic::nvvm_ff2f16x2_rs,
4979 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
4980 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
4981 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
4984 unsigned hasRelu = op.getRelu() ? 1 : 0;
4985 unsigned hasSatFinite =
4986 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4989 unsigned idx = (hasSatFinite << 1) | hasRelu;
4994 if (op.getRandomBits())
4995 args.push_back(mt.
lookupValue(op.getRandomBits()));
4997 switch (op.getRnd()) {
4998 case FPRoundingMode::RN:
4999 return {rndRNIds[idx], std::move(args)};
5000 case FPRoundingMode::RZ:
5001 return {rndRZIds[idx], std::move(args)};
5002 case FPRoundingMode::RS:
5003 return {rndRSIds[idx], std::move(args)};
5005 llvm_unreachable(
"Invalid rounding mode for ConvertF32x2ToF16x2Op");
5010ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
5012 llvm::IRBuilderBase &builder) {
5013 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
5014 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
5015 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
5016 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
5017 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
5019 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
5020 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
5021 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
5022 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
5023 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
5025 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
5026 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
5027 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
5028 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
5029 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
5032 unsigned hasRelu = op.getRelu() ? 1 : 0;
5033 unsigned hasSatFinite =
5034 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
5037 unsigned idx = (hasSatFinite << 1) | hasRelu;
5042 if (op.getRandomBits())
5043 args.push_back(mt.
lookupValue(op.getRandomBits()));
5045 switch (op.getRnd()) {
5046 case FPRoundingMode::RN:
5047 return {rndRNIds[idx], std::move(args)};
5048 case FPRoundingMode::RZ:
5049 return {rndRZIds[idx], std::move(args)};
5050 case FPRoundingMode::RS:
5051 return {rndRSIds[idx], std::move(args)};
5053 llvm_unreachable(
"Invalid rounding mode for ConvertF32x2ToBF16x2Op");
5057llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
5059 bool hasRelu = getRelu();
5062 .Case([&](mlir::Float8E4M3FNType) {
5063 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
5064 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
5066 .Case([&](mlir::Float8E5M2Type) {
5067 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
5068 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
5071 llvm_unreachable(
"Invalid F8 type in ConvertF32x4ToF8x4Op");
5072 return llvm::Intrinsic::not_intrinsic;
5076llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
5078 bool hasRelu = getRelu();
5081 .Case([&](mlir::Float6E2M3FNType) {
5082 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
5083 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
5085 .Case([&](mlir::Float6E3M2FNType) {
5086 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
5087 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
5090 llvm_unreachable(
"Invalid F6 type in ConvertF32x4ToF6x4Op");
5091 return llvm::Intrinsic::not_intrinsic;
5095llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
5097 bool hasRelu = getRelu();
5100 .Case([&](mlir::Float4E2M1FNType) {
5101 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
5102 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
5105 llvm_unreachable(
"Invalid F4 type in ConvertF32x4ToF4x4Op");
5106 return llvm::Intrinsic::not_intrinsic;
5110llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(
Operation &op) {
5111 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
5112 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
5113 auto srcFmt = curOp.getSrcFormat();
5114 auto mc = curOp.getMulticast();
5116 switch (curOp.getShape()) {
5117 case Tcgen05CpShape::SHAPE_128x256b:
5119 case Tcgen05CpShape::SHAPE_128x128b:
5121 case Tcgen05CpShape::SHAPE_4x256b:
5123 case Tcgen05CpShape::SHAPE_32x128b:
5125 case Tcgen05CpShape::SHAPE_64x128b:
5126 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
5130 llvm_unreachable(
"Invalid shape in tcgen05 cp Op");
5137 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
5139 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
5144LogicalResult Tcgen05LdOp::verify() {
5146 if (
getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
5149 if (
getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
5150 result =
emitError(
"offset argument is only supported for shape 16x32bx2");
5152 auto resTy = getRes().getType();
5153 unsigned resLen = isa<VectorType>(resTy)
5154 ? llvm::cast<VectorType>(resTy).getNumElements()
5157 result =
emitError(llvm::formatv(
"invalid result type length {0} for shape "
5158 "{1} in tcgen05.ld Op",
5159 resLen, stringifyEnum(
getShape())));
5164LogicalResult Tcgen05StOp::verify() {
5166 if (
getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
5169 auto valTy = getVal().getType();
5170 unsigned valLen = isa<VectorType>(valTy)
5171 ? llvm::cast<VectorType>(valTy).getNumElements()
5174 result =
emitError(llvm::formatv(
"invalid input length {0} for shape "
5175 "{1} in tcgen05.st Op",
5176 valLen, stringifyEnum(
getShape())));
5186 if (
auto rangeAttr = op->
getAttrOfType<LLVM::ConstantRangeAttr>(
"range")) {
5187 setResultRanges(
result, {rangeAttr.getLower(), rangeAttr.getUpper(),
5188 rangeAttr.getLower(), rangeAttr.getUpper()});
5198 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
5202 const llvm::APInt &lower = rangeAttr->getLower();
5203 const llvm::APInt &upper = rangeAttr->getUpper();
5206 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
5207 unsigned bitWidth = lower.getBitWidth();
5208 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
5209 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
5211 "invalid range attribute: Lower == Upper, but they aren't min (")
5212 << llvm::toString(minVal, 10,
false) <<
") or max ("
5213 << llvm::toString(maxVal, 10,
false)
5214 <<
") value! This is an invalid constant range.";
5221 llvm::IRBuilderBase &builder) {
5222 return builder.CreateBitCast(arg,
5223 llvm::Type::getInt32Ty(builder.getContext()));
5228 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
5235 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
5236 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
5237 unsigned type = (isASigned << 1) | isBSigned;
5238 const llvm::Intrinsic::ID ids[] = {
5239 llvm::Intrinsic::nvvm_idp4a_u_u,
5240 llvm::Intrinsic::nvvm_idp4a_u_s,
5241 llvm::Intrinsic::nvvm_idp4a_s_u,
5242 llvm::Intrinsic::nvvm_idp4a_s_s,
5244 return {ids[type], args};
5249 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
5254 args.push_back(builder.getInt1(curOp.getBHi()));
5257 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
5258 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
5259 unsigned type = (isASigned << 1) | isBSigned;
5260 const llvm::Intrinsic::ID ids[] = {
5261 llvm::Intrinsic::nvvm_idp2a_u_u,
5262 llvm::Intrinsic::nvvm_idp2a_u_s,
5263 llvm::Intrinsic::nvvm_idp2a_s_u,
5264 llvm::Intrinsic::nvvm_idp2a_s_s,
5266 return {ids[type], args};
5270 llvm::IRBuilderBase &builder) {
5271 return builder.CreateAddrSpaceCast(
5272 addr, builder.getPtrTy(llvm::NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM));
5276PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
5278 llvm::IRBuilderBase &builder) {
5279 using MemSpace = NVVM::NVVMMemorySpace;
5280 using CacheLevel = NVVM::PrefetchCacheLevel;
5282 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
5283 std::optional<NVVM::CacheEvictionPriority> evictPriority =
5284 op.getEvictPriority();
5285 unsigned addressSpace =
5286 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
5294 if (op.getTensormap())
5295 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
5297 assert(cacheLevel &&
"expected cache level for non-tensormap prefetch");
5299 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
5300 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
5302 if (evictPriority && *cacheLevel == CacheLevel::L2) {
5303 switch (*evictPriority) {
5304 case NVVM::CacheEvictionPriority::EvictLast:
5305 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
5306 case NVVM::CacheEvictionPriority::EvictNormal:
5307 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
5309 llvm_unreachable(
"Invalid cache eviction priority");
5313 switch (
static_cast<MemSpace
>(addressSpace)) {
5314 case MemSpace::Generic:
5315 return *cacheLevel == CacheLevel::L1
5317 :
NVVM::
IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
5318 case MemSpace::Global:
5319 return *cacheLevel == CacheLevel::L1
5321 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
5323 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
5324 case MemSpace::Local:
5325 return *cacheLevel == CacheLevel::L1
5327 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
5329 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
5331 llvm_unreachable(
"Invalid pointer address space");
5335bool NVVM::InlinePtxOp::getAsmValues(
5339 for (
auto arg : getReadWriteArgs())
5341 for (
auto arg : getResults())
5343 for (
auto arg : getReadOnlyArgs())
5350NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
5352 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
5354 args.push_back(mt.
lookupValue(curOp.getSmemAddress()));
5355 args.push_back(mt.
lookupValue(curOp.getMbarrier()));
5357 llvm::Intrinsic::ID intrinsicID =
5358 curOp.getMulticast()
5360 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
5361 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
5363 return {intrinsicID, args};
5366NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
5368 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
5370 args.push_back(mt.
lookupValue(curOp.getTryCancelResponse()));
5372 llvm::Intrinsic::ID intrinsicID;
5374 switch (curOp.getQueryType()) {
5375 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
5377 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
5379 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
5380 intrinsicID = llvm::Intrinsic::
5381 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
5383 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
5384 intrinsicID = llvm::Intrinsic::
5385 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
5387 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
5388 intrinsicID = llvm::Intrinsic::
5389 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
5392 return {intrinsicID, args};
5397 llvm::IRBuilderBase &builder) {
5398 auto thisOp = cast<NVVM::PermuteOp>(op);
5399 NVVM::PermuteMode mode = thisOp.getMode();
5401 static constexpr llvm::Intrinsic::ID IDs[] = {
5402 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
5403 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
5404 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
5405 llvm::Intrinsic::nvvm_prmt_rc16};
5407 unsigned modeIndex =
static_cast<unsigned>(mode);
5415 args.push_back(mt.
lookupValue(thisOp.getSelector()));
5417 return {IDs[modeIndex], args};
5422 auto thisOp = cast<NVVM::TensormapReplaceOp>(op);
5426 if (thisOp.getOrd())
5427 args.push_back(builder.getInt32(thisOp.getOrd().value()));
5428 if (thisOp.getNewValue())
5429 args.push_back(mt.
lookupValue(thisOp.getNewValue()));
5430 if (
auto attr = thisOp.getNewValueAttr()) {
5433 .Case<TensormapElemtypeAttr, TensormapInterleaveLayoutAttr,
5434 TensormapSwizzleModeAttr, TensormapSwizzleAtomicityAttr,
5435 TensormapFillModeAttr>([](
auto attr) {
5436 return static_cast<unsigned>(attr.getValue());
5438 .Default([](
auto attr) {
5439 llvm_unreachable(
"Invalid attribute type");
5442 args.push_back(builder.getInt32(val));
5445 static constexpr llvm::Intrinsic::ID IDs[] = {
5446 llvm::Intrinsic::nvvm_tensormap_replace_global_address,
5447 llvm::Intrinsic::nvvm_tensormap_replace_rank,
5448 llvm::Intrinsic::nvvm_tensormap_replace_box_dim,
5449 llvm::Intrinsic::nvvm_tensormap_replace_global_dim,
5450 llvm::Intrinsic::nvvm_tensormap_replace_global_stride,
5451 llvm::Intrinsic::nvvm_tensormap_replace_element_stride,
5452 llvm::Intrinsic::nvvm_tensormap_replace_elemtype,
5453 llvm::Intrinsic::nvvm_tensormap_replace_interleave_layout,
5454 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_mode,
5455 llvm::Intrinsic::nvvm_tensormap_replace_swizzle_atomicity,
5456 llvm::Intrinsic::nvvm_tensormap_replace_fill_mode,
5459 unsigned fieldIndex =
static_cast<unsigned>(thisOp.getField());
5461 return {IDs[fieldIndex], args};
5470 llvm::IRBuilderBase &builder) {
5472 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
5475 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5478 const bool isATensor = isa<llvm::PointerType>(
A->getType());
5481 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5482 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5483 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5485 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5486 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5487 using IsATensorArray = std::array<CtaGroupArray, 2>;
5488 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5489 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5492 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
5498 {llvm::Intrinsic::nvvm_tcgen05_mma_shared,
notIntrinsic},
5500 {llvm::Intrinsic::nvvm_tcgen05_mma_shared,
notIntrinsic}}},
5504 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5505 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5509 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
5510 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
5516 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d,
notIntrinsic},
5518 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d,
notIntrinsic}}},
5522 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5523 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5527 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
5528 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
5534 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
5537 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
5542 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
5544 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
5549 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
5551 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
5557 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
5561 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
5566 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
5568 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
5572 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
5574 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
5577 llvm::Value *ScaleInputD = mt.
lookupValue(thisOp.getScaleInputD());
5578 bool hasScaleInputD = ScaleInputD !=
nullptr;
5580 llvm::Value *DisableOutputLane =
5582 bool hasDisableOutputLane = DisableOutputLane !=
nullptr;
5584 const unsigned ctaGroup =
5587 llvm::Intrinsic::ID ID =
5588 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5589 [ctaGroup - 1][thisOp.getAShift()];
5591 assert(ID !=
notIntrinsic &&
"Invalid intrinsic for Tcgen05MMAOp.");
5594 args.push_back(ScaleInputD);
5596 if (hasDisableOutputLane)
5597 args.push_back(DisableOutputLane);
5599 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5601 if (!hasDisableOutputLane)
5602 args.push_back(builder.getInt32(ctaGroup));
5605 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5612 NVVM::CTAGroupKind ctaGroup,
bool hasAShift,
5613 NVVM::Tcgen05MMACollectorOp collectorOp,
Location loc) {
5615 if (disableOutputLane) {
5616 mlir::VectorType disableOutputLaneType =
5617 cast<mlir::VectorType>(disableOutputLane.
getType());
5618 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
5619 disableOutputLaneType.getNumElements() != 4) ||
5620 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
5621 disableOutputLaneType.getNumElements() != 8))
5622 return emitError(loc) <<
"Disable Output Lane of length "
5623 << disableOutputLaneType.getNumElements()
5624 <<
" is incompatible with CtaGroupAttr";
5627 if (hasAShift && !isATensor)
5629 loc,
"A-shift can be applied only when matrix A is in tensor memory");
5631 if (hasAShift ==
true && (collectorOp == Tcgen05MMACollectorOp::FILL ||
5632 collectorOp == Tcgen05MMACollectorOp::USE))
5634 loc,
"Cannot use collector buffer operation fill or use with ashift");
5639LogicalResult Tcgen05MMAOp::verify() {
5641 getDisableOutputLane(), getCtaGroup(), getAShift(),
5642 getCollectorOp(), getLoc());
5652 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
5655 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5658 bool isATensor = isa<llvm::PointerType>(
A->getType());
5661 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5662 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5663 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5664 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
5666 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
5667 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
5668 using IsATensorArray = std::array<CtaGroupArray, 2>;
5669 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
5670 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
5673 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
5679 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared,
notIntrinsic},
5681 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared,
notIntrinsic}}},
5685 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5686 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5690 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5691 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5697 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5700 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5705 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5706 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5710 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5711 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5718 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5722 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5727 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5729 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5734 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5736 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5742 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5746 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5751 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5753 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5757 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5759 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5762 llvm::Value *ScaleInputD = mt.
lookupValue(thisOp.getScaleInputD());
5763 bool hasScaleInputD = ScaleInputD !=
nullptr;
5765 llvm::Value *DisableOutputLane =
5767 bool hasDisableOutputLane = DisableOutputLane !=
nullptr;
5772 llvm::Intrinsic::ID ID =
5773 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5774 [ctaGroup - 1][thisOp.getAShift()];
5776 assert(ID !=
notIntrinsic &&
"Invalid intrinsic for Tcgen05MMASparseOp.");
5779 args.push_back(ScaleInputD);
5781 if (hasDisableOutputLane)
5782 args.push_back(DisableOutputLane);
5784 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
5786 if (!hasDisableOutputLane)
5787 args.push_back(builder.getInt32(ctaGroup));
5790 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5795LogicalResult Tcgen05MMASparseOp::verify() {
5797 getDisableOutputLane(), getCtaGroup(), getAShift(),
5798 getCollectorOp(), getLoc());
5808 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5811 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5814 bool isATensor = isa<llvm::PointerType>(
A->getType());
5817 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5818 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5819 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5820 args.push_back(mt.
lookupValue(thisOp.getScaleA()));
5821 args.push_back(mt.
lookupValue(thisOp.getScaleB()));
5822 args.push_back(builder.getInt32(
5825 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5827 auto kind = thisOp.getKind();
5828 auto blockScale = thisOp.getBlockScale();
5829 llvm::Intrinsic::ID ID = [&]() {
5830 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5831 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5832 return isATensor ? llvm::Intrinsic::
5833 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
5835 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
5836 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5839 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
5841 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
5843 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5844 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5846 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
5847 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
5848 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5849 return isATensor ? llvm::Intrinsic::
5850 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
5852 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
5854 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5855 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5858 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
5860 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
5862 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5865 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
5867 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
5870 llvm_unreachable(
"Invalid tcgen05.mma.block_scale attributes");
5877 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind,
5878 NVVM::Tcgen05MMABlockScale blockScale,
Location loc) {
5879 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
5880 kind == NVVM::Tcgen05MMAKind::MXF4NVF4)
5881 return emitError(loc,
"mxf4nvf4 requires block scale attribute");
5883 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
5884 kind != NVVM::Tcgen05MMAKind::MXF4NVF4)
5886 llvm::formatv(
"{} kind does not support block16 attribute",
5887 stringifyEnum(kind)));
5892LogicalResult Tcgen05MMABlockScaleOp::verify() {
5894 getBlockScale(), getLoc());
5904 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
5907 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5910 bool isATensor = isa<llvm::PointerType>(
A->getType());
5913 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5914 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5915 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
5916 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
5917 args.push_back(mt.
lookupValue(thisOp.getScaleA()));
5918 args.push_back(mt.
lookupValue(thisOp.getScaleB()));
5919 args.push_back(builder.getInt32(
5922 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
5924 auto kind = thisOp.getKind();
5925 auto blockScale = thisOp.getBlockScale();
5926 llvm::Intrinsic::ID ID = [&]() {
5927 if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
5928 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5929 return isATensor ? llvm::Intrinsic::
5930 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
5932 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
5933 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5936 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
5938 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
5940 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
5941 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5942 return isATensor ? llvm::Intrinsic::
5943 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
5945 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
5946 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5949 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
5951 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
5953 }
else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
5954 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5957 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
5959 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
5961 }
else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5964 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
5966 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
5969 llvm_unreachable(
"Invalid tcgen05.mma.sp.block_scale attributes");
5975LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
5977 getBlockScale(), getLoc());
5987 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
5990 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
5993 bool isATensor = isa<llvm::PointerType>(
A->getType());
5996 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
5997 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
5998 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
6000 mlir::Value ZeroColMask = thisOp.getZeroColMask();
6004 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
6005 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
6007 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
6008 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
6010 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
6012 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorBBuffer())));
6014 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
6026 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
6029 args.push_back(mt.
lookupValue(thisOp.getMatrixD()));
6032 bool isATensor = isa<llvm::PointerType>(
A->getType());
6035 args.push_back(mt.
lookupValue(thisOp.getMatrixB()));
6036 args.push_back(mt.
lookupValue(thisOp.getIdesc()));
6037 args.push_back(mt.
lookupValue(thisOp.getEnableInputD()));
6038 args.push_back(mt.
lookupValue(thisOp.getSparseMetadata()));
6040 mlir::Value ZeroColMask = thisOp.getZeroColMask();
6045 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
6046 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
6048 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
6049 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
6051 args.push_back(builder.getInt32(
static_cast<unsigned>(thisOp.getKind())));
6053 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorBBuffer())));
6055 builder.getInt32(
static_cast<unsigned>(thisOp.getCollectorOp())));
6064#define TCGEN05LDRED(SHAPE, NUM, TYPE) \
6065 llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE
6069 auto thisOp = cast<NVVM::Tcgen05LdRedOp>(op);
6072 mlir::VectorType VecResTy =
6073 cast<mlir::VectorType>(thisOp.getData().getType());
6074 unsigned Num = VecResTy.getNumElements();
6075 bool IsFloat = thisOp.getRedVal().getType().isF32();
6077 llvm::Intrinsic::ID Shape32x32b[][2] = {
6088 llvm::Intrinsic::ID Shape16x32bx2[][2] = {
6099 NVVM::Tcgen05LdStShape
shape = thisOp.getShape();
6100 unsigned ID = [&]() {
6103 unsigned idx = std::log2(Num);
6105 case NVVM::Tcgen05LdStShape::SHAPE_32X32B:
6106 return Shape32x32b[idx][IsFloat];
6107 case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2:
6108 return Shape16x32bx2[idx][IsFloat];
6110 llvm_unreachable(
"unhandled tcgen05.ld lowering");
6116 if (
shape == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2)
6117 args.push_back(mt.
lookupValue(thisOp.getOffset()));
6120 builder.getInt32(thisOp.getOp() == NVVM::ReductionKind::MIN ? 0 : 1));
6123 args.push_back(builder.getInt1(
static_cast<unsigned>(thisOp.getAbs())));
6124 args.push_back(builder.getInt1(
static_cast<unsigned>(thisOp.getNan())));
6129LogicalResult Tcgen05LdRedOp::verify() {
6130 VectorType data = cast<VectorType>(getData().
getType());
6131 Type redVal = getRedVal().getType();
6133 if (data.getElementType() != redVal)
6135 "type of reduction value and element type of vector data should match");
6137 if (getOp() != NVVM::ReductionKind::MIN &&
6138 getOp() != NVVM::ReductionKind::MAX)
6139 return emitError(
"only min and max reduction kinds are supported");
6141 if (redVal.
isInteger() && (getAbs() || getNan())) {
6142 return emitError(
"abs or nan is only applicable for f32 type");
6152void NVVMDialect::initialize() {
6155#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
6158#define GET_ATTRDEF_LIST
6159#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
6164 allowUnknownOperations();
6165 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
6166 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
6169LogicalResult NVVMDialect::verifyOperationAttribute(
Operation *op,
6171 StringAttr attrName = attr.
getName();
6173 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
6174 if (!isa<LLVM::LLVMFuncOp>(op)) {
6175 return op->
emitError() <<
"'" << NVVMDialect::getKernelFuncAttrName()
6176 <<
"' attribute attached to unexpected op";
6181 if (attrName == NVVMDialect::getMaxntidAttrName() ||
6182 attrName == NVVMDialect::getReqntidAttrName() ||
6183 attrName == NVVMDialect::getClusterDimAttrName()) {
6184 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
6185 if (!values || values.empty() || values.size() > 3) {
6188 <<
"' attribute must be integer array with maximum 3 index";
6193 if (attrName == NVVMDialect::getMinctasmAttrName() ||
6194 attrName == NVVMDialect::getMaxnregAttrName() ||
6195 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
6196 if (!llvm::dyn_cast<IntegerAttr>(attr.
getValue())) {
6198 <<
"'" << attrName <<
"' attribute must be integer constant";
6202 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
6203 if (!op->
hasAttr(NVVMDialect::getReqntidAttrName()) ||
6204 !op->
hasAttr(NVVMDialect::getClusterDimAttrName())) {
6206 <<
"'" << attrName <<
"' attribute must be used along with "
6207 <<
"'" << NVVMDialect::getReqntidAttrName() <<
"' and "
6208 <<
"'" << NVVMDialect::getClusterDimAttrName() <<
"'";
6215LogicalResult NVVMDialect::verifyRegionArgAttribute(
Operation *op,
6216 unsigned regionIndex,
6219 auto funcOp = dyn_cast<FunctionOpInterface>(op);
6223 bool isKernel = op->
hasAttr(NVVMDialect::getKernelFuncAttrName());
6224 StringAttr attrName = argAttr.
getName();
6225 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
6229 <<
"' attribute must be present only on kernel arguments";
6231 if (!isa<UnitAttr>(argAttr.
getValue()))
6232 return op->
emitError() <<
"'" << attrName <<
"' must be a unit attribute";
6233 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
6236 <<
"' attribute requires the argument to also have attribute '"
6237 << LLVM::LLVMDialect::getByValAttrName() <<
"'";
6248unsigned NVVMMemorySpaceAttr::getAddressSpace()
const {
6249 return static_cast<unsigned>(getValue());
6252bool NVVMMemorySpaceAttr::isValidLoad(
6253 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
6254 const ::mlir::DataLayout *dataLayout,
6260bool NVVMMemorySpaceAttr::isValidStore(
6261 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
6262 const ::mlir::DataLayout *dataLayout,
6268bool NVVMMemorySpaceAttr::isValidAtomicOp(
6269 ptr::AtomicBinOp op,
Type type, ptr::AtomicOrdering ordering,
6270 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
6273 assert(
false &&
"unimplemented, see TODO in the source.");
6277bool NVVMMemorySpaceAttr::isValidAtomicXchg(
6278 Type type, ptr::AtomicOrdering successOrdering,
6279 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
6280 const ::mlir::DataLayout *dataLayout,
6283 assert(
false &&
"unimplemented, see TODO in the source.");
6287bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
6291 assert(
false &&
"unimplemented, see TODO in the source.");
6295bool NVVMMemorySpaceAttr::isValidPtrIntCast(
6300 assert(
false &&
"unimplemented, see TODO in the source.");
6309 int optLevel, StringRef triple, StringRef chip,
6310 StringRef features, DictionaryAttr flags,
6312 if (optLevel < 0 || optLevel > 3) {
6313 emitError() <<
"The optimization level must be a number between 0 and 3.";
6316 if (triple.empty()) {
6317 emitError() <<
"The target triple cannot be empty.";
6321 emitError() <<
"The target chip cannot be empty.";
6324 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
6325 return mlir::isa_and_nonnull<StringAttr>(attr);
6327 emitError() <<
"All the elements in the `link` array must be strings.";
6333LogicalResult NVVMTargetAttr::verifyTarget(
Operation *gpuModule) {
6334 if (!getVerifyTarget())
6337 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
6340 "NVVM target attribute must be attached to a GPU module");
6343 const unsigned targetFullSmVersion =
6347 "Minimum NVVM target SM version is sm_20");
6351 ->
walk([&](Operation *op) {
6352 if (
auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
6353 const NVVMCheckSMVersion requirement =
6354 reqOp.getRequiredMinSMVersion();
6356 op->
emitOpError() <<
"is not supported on " << getChip();
6368#define GET_OP_CLASSES
6369#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
6371#define GET_ATTRDEF_CLASSES
6372#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred, TypeRange actual)
For ops with optional results, allow the user to omit the result even when inference would produce on...
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
static void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyAddSubFOp(OpType op)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
static void printOperandList(OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static llvm::Value * castPtrToAddrSpace(llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
static LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > ®s)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
#define TCGEN05LDRED(SHAPE, NUM, TYPE)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static MMATypes inferPtxTypeFromResult(OpTy op)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static LogicalResult parseMmaTypeSignature(OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
static bool isPtrInSharedClusterSpace(mlir::Value ptr)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static bool isPtrInGenericSpace(mlir::Value ptr)
static void processOperandFragments(Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > ®Types, SmallVectorImpl< StringRef > &ignoreAttrNames)
static constexpr unsigned notIntrinsic
static LogicalResult inferMBarrierArriveResultTypes(MLIRContext *context, Value addr, SmallVectorImpl< Type > &inferredReturnTypes)
Only shared_cluster (ptr<7>) produces zero results; all other address spaces (including generic) retu...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class represents a diagnostic that is inflight and set to be reported.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given chracteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
LogicalResult matchAndRewrite(SubFOp op, PatternRewriter &rewriter) const override
static bool isMinimumSMVersion(unsigned fullSmVersion)
static unsigned getTargetFullSmVersionFromStr(StringRef smVersionString)
bool isCompatibleWith(const unsigned &targetFullSmVersion) const
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.