#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
  auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
  return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);

                                     NVVMMemorySpace targetAS) {
  unsigned AS = static_cast<unsigned>(targetAS);
  return builder.CreateAddrSpaceCast(
      ptr, llvm::PointerType::get(builder.getContext(), AS));
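// The two helpers above handle NVVM memory spaces: the first compares a
// pointer's LLVM address space against an NVVMMemorySpace value, the second
// rewrites a pointer into the requested address space with an addrspacecast.
// Illustrative use (the helper names here are hypothetical; the real
// signatures are defined earlier in this file):
//   if (!isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster))
//     ptr = castPtrToAddrSpace(builder, ptr, NVVMMemorySpace::SharedCluster);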
static llvm::nvvm::CTAGroupKind
  case NVVM::CTAGroupKind::CTA_1:
    return llvm::nvvm::CTAGroupKind::CG_1;
  case NVVM::CTAGroupKind::CTA_2:
    return llvm::nvvm::CTAGroupKind::CG_2;
  llvm_unreachable("unsupported cta_group value");
                                   size_t numIm2ColOffsets,
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

        "to use im2col mode, the tensor has to be at least 3-dimensional");
  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
    return emitError(
        loc, "im2col offsets must be 2 less than number of coordinates");
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  TMAStoreMode mode = getMode();
  if (getPredicate()) {
    if (mode != TMAStoreMode::TILE)
      return emitError("Inline-ptx lowering supported only for Tile mode.");
    if (getL2CacheHint())
      return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");

  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter4 mode expects 5 coordinates");

LogicalResult CpAsyncOp::verify() {
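  // Constraints checked below: the cache modifier must be CG or CA, the copy
  // size must be 4, 8, or 16 bytes, and CG is only valid for 16-byte copies
  // (matching the PTX-level cp.async restrictions).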
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only support for 16 bytes copy.");
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedIm2colOff) -> LogicalResult {
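    // Shared check for all TMA load modes: the im2col-family modes require at
    // least a 3-D tensor, and the number of im2col offsets must match the
    // per-mode expectation passed in (0 for tile, dims - 2 for im2col, 2 for
    // the im2col_w variants, and 0 for gather4, as dispatched below).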
    if (isIm2col && (tensorDims < 3))
      return emitError(loc)
             << "to use " << stringifyEnum(mode)
             << " mode, the tensor has to be at least 3-dimensional";

    if (numIm2colOff != expectedIm2colOff)
      return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
                            << " (provided " << numIm2colOff << ")";

  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    return (tensorDims == 5)
               ? checkTMALoadParams(mode, false, 0)
               : emitError(loc, "Gather4 mode expects 5 coordinates");
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
                             getMode(), getLoc());

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  TMALoadMode mode = getMode();
  bool isCTAOnly = getIsCTAOnly();
  if (getPredicate()) {
      return emitError("Predicate is supported only for shared::cluster mode.");
    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
      return emitError(
          "Predicate is supported only for Tile and Im2col modes.");

  NVVMMemorySpace expectedAS =
      isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
                    .getAddressSpace();
  if (AS != expectedAS)
                   ? "Shared::cta destination requires address-space 3."
                   : "Shared::cluster destination requires address-space 7.");

  if (getMulticastMask())
    return emitError("Multicast is not supported with shared::cta mode.");
    return emitError("CTAGroup is not supported with shared::cta mode.");

                             getMode(), getLoc());
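// CpAsyncBulkTensorGlobalToSharedClusterOp: with is_cta_only the destination
// must live in address space 3 (shared::cta) and neither multicast_mask nor
// cta_group are allowed; otherwise the destination must be in address space 7
// (shared::cluster). The predicated form is limited to tile and im2col modes.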
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  TMAStoreMode mode = getMode();
  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");

LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
  if (isSharedCTA && getMulticastMask())
    return emitError("Multicast is not supported with shared::cta mode.");
                                     NVVM::MemScopeKind scope,
                                     Value retVal = nullptr) {
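  // Shared mbarrier verification: the scope attribute must be CTA or Cluster,
  // and an mbarrier placed in shared_cluster memory may not produce a result.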
  if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
    return op->emitError("mbarrier scope must be either CTA or Cluster");

  bool hasRetValue = static_cast<bool>(retVal);
  if (isSharedCluster && hasRetValue)
    return op->emitError(
        "mbarrier in shared_cluster space cannot return any value");
LogicalResult MBarrierArriveOp::verify() {

LogicalResult MBarrierArriveDropOp::verify() {

LogicalResult MBarrierArriveExpectTxOp::verify() {
  if (getPredicate()) {
    if (getScope() != NVVM::MemScopeKind::CTA)
      return emitError("mbarrier scope must be CTA when using predicate");
      return emitError("mbarrier in shared_cluster space is not supported when "
      return emitError("return-value is not supported when using predicate");
    if (getRelaxed() == true)
      return emitError("mbarrier with relaxed semantics is not supported when "

LogicalResult MBarrierArriveDropExpectTxOp::verify() {

LogicalResult MBarrierExpectTxOp::verify() {

LogicalResult MBarrierCompleteTxOp::verify() {

LogicalResult MBarrierTestWaitOp::verify() {

LogicalResult MBarrierTryWaitOp::verify() {

LogicalResult ConvertFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
    return emitError("Relu not supported with rna rounding mode.");
  return emitError(
      "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
LogicalResult ConvertF32x2ToF6x2Op::verify() {
  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
    return emitOpError("Only ")
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f32x2 to f6x2.";

LogicalResult ConvertF32x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;

  bool hasRelu = getRelu();
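  // Allowed combinations checked by the TypeSwitch below:
  //   f8E4M3FN / f8E5M2 destinations -> rnd must be RN, sat must be SATFINITE
  //   f8E8M0FNU destination          -> rnd must be RZ or RP, relu unsupported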
      .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
            if (!isRoundingModeRN) {
              return emitOpError("Only RN rounding mode is supported for "
                                 "conversions from f32x2 to ")
                     << mlir::Float8E4M3FNType::get(ctx) << " and "
                     << mlir::Float8E5M2Type::get(ctx) << " types";

            return emitOpError("Only SATFINITE saturation mode is supported "
                   << mlir::Float8E4M3FNType::get(ctx) << " and "
                   << mlir::Float8E5M2Type::get(ctx) << " types";

      .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
        if (!(isRoundingModeRZ || isRoundingModeRP)) {
          return emitOpError("Only RZ and RP rounding modes are supported for "
                             "conversions from f32x2 to ")
                 << mlir::Float8E8M0FNUType::get(ctx) << " type";

        return emitOpError("relu not supported for conversions to ")
               << mlir::Float8E8M0FNUType::get(ctx) << " type";

             << mlir::Float8E4M3FNType::get(ctx) << ", "
             << mlir::Float8E5M2Type::get(ctx) << ", and "
             << mlir::Float8E8M0FNUType::get(ctx)
                "supported for conversions from f32x2 to f8x2";
LogicalResult ConvertF16x2ToF8x2Op::verify() {
  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
    return emitOpError("Only ")
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f16x2 to f8x2.";

LogicalResult ConvertBF16x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
           << " type is supported for conversions from "

  if (rnd != RndMode::RZ && rnd != RndMode::RP)
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");

LogicalResult ConvertF32x2ToF4x2Op::verify() {
  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    return emitOpError("Only ")
           << mlir::Float4E2M1FNType::get(ctx)
           << " type is supported for conversions from f32x2 to f4x2.";

LogicalResult ConvertF8x2ToF16x2Op::verify() {
  if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
    return emitOpError("Only ")
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f8x2 to f16x2.";

LogicalResult ConvertF8x2ToBF16x2Op::verify() {
  if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
    return emitOpError("Only ")
           << mlir::Float8E8M0FNUType::get(ctx)
           << " type is supported for conversions from f8x2 to bf16x2.";

LogicalResult ConvertF6x2ToF16x2Op::verify() {
  if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
    return emitOpError("Only ")
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f6x2 to f16x2.";

LogicalResult ConvertF4x2ToF16x2Op::verify() {
  if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
    return emitOpError("Only ")
           << mlir::Float4E2M1FNType::get(ctx)
           << " type is supported for conversions from f4x2 to f16x2.";
LogicalResult PermuteOp::verify() {
  using Mode = NVVM::PermuteMode;
  bool hasHi = static_cast<bool>(getHi());

           << stringifyPermuteMode(getMode()) << "' requires 'hi' operand.";

    return emitError("mode '") << stringifyPermuteMode(getMode())
                               << "' does not accept 'hi' operand.";

  static constexpr FPRoundingMode validRndModes[] = {
      FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};
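  // This shared checker for the f32x2 conversion ops accepts RN, RZ, and RS
  // rounding; the random_bits operand is mandatory with RS and rejected for
  // RN/RZ, as enforced below.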
  if (!llvm::is_contained(validRndModes, rnd)) {
                   "Only RN, RZ, and RS rounding modes are supported for "
                   "conversions from f32x2 to ")

  if (rnd == FPRoundingMode::RS) {
    if (!hasRandomBits) {
      return op->emitOpError("random_bits is required for RS rounding mode.");

        "random_bits not supported for RN and RZ rounding modes.");
LogicalResult ConvertF32x2ToF16x2Op::verify() {
                               getRandomBits() ? true : false, *this);

LogicalResult ConvertF32x2ToBF16x2Op::verify() {
                               getRandomBits() ? true : false, *this);

LogicalResult ConvertF32x4ToF8x4Op::verify() {
  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
    return emitOpError("Only ")
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f32x4 to f8x4.";

LogicalResult ConvertF32x4ToF6x4Op::verify() {
  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
    return emitOpError("Only ")
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f32x4 to f6x4.";

LogicalResult ConvertF32x4ToF4x4Op::verify() {
  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
                                << " type is supported for conversions from "

LogicalResult BulkStoreOp::verify() {
  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
LogicalResult PMEventOp::verify() {
  auto eventId = getEventId();
  auto maskedEventId = getMaskedEventId();
  if (!maskedEventId && !eventId) {
    return emitOpError() << "either `id` or `mask` must be set";

  if (maskedEventId && eventId) {
    return emitOpError() << "`id` and `mask` cannot be set at the same time";

  if (eventId < 0 || eventId > 15) {
    return emitOpError() << "`id` must be between 0 and 15";

  return llvm::success();
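// PMEventOp therefore accepts exactly one of `id` (0-15) or `mask`,
// presumably mapping to the plain and .mask forms of the PTX pmevent
// instruction respectively.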
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
      return NVVM::MMATypes::s32;

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
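// Inference summary: f64 -> f64, f16 or vector<2xf16> -> f16, f32 -> f32 when
// used as an accumulator and tf32 otherwise, integer types map to s32 for
// accumulators, and struct operands are inferred from their first element.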
  return (type == MMATypes::u4 || type == MMATypes::s4);

  return (type == MMATypes::u8 || type == MMATypes::s8);

         type == MMATypes::s32;

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getODSOperands(2).getTypes().front(), true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  struct MMAOperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}

  std::array<MMAOperandFragment, 3> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());

    std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
        regTypes.back(), fragIdx >= 2);
      ignoreAttrNames.push_back(frag.ptxTypeAttr);

  auto printMmaOperand = [&](const MMAOperandFragment &frag) -> void {
    p << " " << frag.operandName;

  for (const auto &frag : frags) {
    printMmaOperand(frag);

                      frags[1].regs[0].getType(),
                      frags[2].regs[0].getType()},
                   std::optional<MMAIntOverflow> intOverflow,
                   std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                   std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
                      MmaOp::getOperandSegmentSizeAttr(),
                      static_cast<int32_t>(operandB.size()),
                      static_cast<int32_t>(operandC.size())}));
  struct MMAOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;

  std::array<MMAOperandFragment, 4> frags;

                               MMAOperandFragment &frag) -> LogicalResult {

  if (operandTypes.size() != 3)
                     "expected one type for each operand segment but got " +
                         Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],

  frags[3].elemtype = inferOperandMMAType(resultType, true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
                       "attribute " + names[idx] +
                           " is not provided explicitly and cannot be inferred");
    if (!attr.has_value())
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
LogicalResult MmaOp::verify() {
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;

  if (mmaShape[0] == 16) {
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));

      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});

  if (resultPtxType() != accumPtxType())
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});

    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
      allowedShapes.push_back({8, 8, 4});

      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
        allowedShapes.push_back({8, 8, 32});
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);

  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
    auto spec = this->getODSOperandIndexAndLength(iter.index());
                                    operand_type_begin() + spec.first +
    bool match = llvm::is_contained(iter.value(), operandTySeg);

      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);

  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +

  if (!getIntOverflowBehavior())
                       getIntOverflowBehaviorAttrName().strref() +

  bool isM8N8K4_F16 =
      (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
       getMultiplicandAPtxType() == MMATypes::f16);
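  // The only MMA variant that tolerates layouts other than row-major A and
  // column-major B is m8n8k4 with f16 operands; every other shape/type
  // combination is rejected below.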
  if (!isM8N8K4_F16) {
    if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
      return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
                         "layoutB = #nvvm.mma_layout<col> for shape <")
             << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
             << "> with element types "
             << stringifyEnum(*getMultiplicandAPtxType()) << " and "
             << stringifyEnum(*getMultiplicandBPtxType())
             << ". Only m8n8k4 with f16 supports other layouts.";
MMATypes MmaSpOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
      getODSOperands(2).getTypes().front(), true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");

MMATypes MmaSpOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      MmaOp::inferOperandMMAType(getResult().getType(), true);
  assert(val.has_value() && "result PTX type should always be inferrable");

                                   llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::MmaSpOp>(op);

  auto intId = MmaSpOp::getIntrinsicID(
      thisOp.getShape().getM(), thisOp.getShape().getN(),
      thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
      thisOp.getOrderedMetadata(), thisOp.getKind(),
      *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
      thisOp.accumPtxType(), thisOp.resultPtxType());

  return {intId, args};
  struct MMAOperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}

  std::array<MMAOperandFragment, 5> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", ""), MMAOperandFragment("sparseMetadata", ""),
      MMAOperandFragment("selector", "")};
      mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == varOperandSpec.first) {
        regTypes.push_back(this->getOperand(operandIdx).getType());

    std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
        regTypes.back(), fragIdx >= 2);
      ignoreAttrNames.push_back(frag.ptxTypeAttr);

  frags[3].regs.push_back(getSparseMetadata());
  frags[4].regs.push_back(getSparsitySelector());

  auto printMmaSpOperand = [&](const MMAOperandFragment &frag) -> void {
    p << " " << frag.operandName;

  for (const auto &frag : frags)
    printMmaSpOperand(frag);

  for (int i = 0; i < 3; ++i) {

  p << ") -> " << getResult().getType();
                     std::optional<MMAIntOverflow> intOverflow,
                     std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
  result.addOperands(sparseMetadata);
  result.addOperands(sparsitySelector);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
    if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));

  result.addTypes(resultType);
                      MmaSpOp::getOperandSegmentSizeAttr(),
                      static_cast<int32_t>(operandB.size()),
                      static_cast<int32_t>(operandC.size()), 1,
  struct MMAOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;

  std::array<MMAOperandFragment, 6> frags;

  auto parseMmaSpOperand = [&](StringRef operandName,
                               MMAOperandFragment &frag) -> LogicalResult {

  if (parseMmaSpOperand("A", frags[0]).failed())
  if (parseMmaSpOperand("B", frags[1]).failed())
  if (parseMmaSpOperand("C", frags[2]).failed())
  if (parseMmaSpOperand("sparseMetadata", frags[3]).failed())
  if (parseMmaSpOperand("selector", frags[4]).failed())

  if (operandTypes.size() != 3)
                     "expected one type for each operand segment but got " +
                         Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
        MmaOp::inferOperandMMAType(frag.regTypes[0],

      MmaOp::inferOperandMMAType(resultType, true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
                       "attribute " + names[idx] +
                           " is not provided explicitly and cannot be inferred");
    if (!attr.has_value())
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(),
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
LogicalResult MmaSpOp::verify() {
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;

  if (mmaShape[0] == 16) {
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      allowedShapes.push_back({16, 8, 8});
      allowedShapes.push_back({16, 8, 16});
    case MMATypes::bf16:
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      allowedShapes.push_back({16, 8, 16});
      allowedShapes.push_back({16, 8, 32});
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      allowedShapes.push_back({16, 8, 16});
      allowedShapes.push_back({16, 8, 32});
      allowedShapes.push_back({16, 8, 64});
      allowedShapes.push_back({16, 8, 128});
      allowedShapes.push_back({16, 8, 32});
      allowedShapes.push_back({16, 8, 64});
    case MMATypes::e4m3:
    case MMATypes::e5m2:
    case MMATypes::e3m2:
    case MMATypes::e2m3:
    case MMATypes::e2m1:
      multiplicandFragType = i32Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      allowedShapes.push_back({16, 8, 64});
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));

      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 &&
               *getMultiplicandAPtxType() <= MMATypes::e2m1) {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
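    // Note the division by two in unitA relative to the dense MmaOp path:
    // presumably the A fragment needs half as many registers because the
    // sparse variant stores A in 2:4 structured-sparse form.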
  if (resultPtxType() != accumPtxType())

  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});

    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
      allowedShapes.push_back({8, 8, 4});

      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
        allowedShapes.push_back({8, 8, 32});
        allowedShapes.push_back({8, 8, 16});

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);

  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
    auto spec = this->getODSOperandIndexAndLength(iter.index());
                                    operand_type_begin() + spec.first +
    bool match = llvm::is_contained(iter.value(), operandTySeg);

      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);

  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();

  if (!getIntOverflowBehavior())
                       getIntOverflowBehaviorAttrName().strref() +

  if (!getSparseMetadata().getType().isInteger(32)) {
    return emitOpError() << "sparse metadata must be i32 type";

  if (!getSparsitySelector().getType().isInteger(32)) {
    return emitOpError() << "sparsity selector must be i32 type";
struct MMAOperandFragment {
  StringRef operandName;
  StringRef ptxTypeAttr;
  SmallVector<Value, 4> regs;
  explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
      : operandName(name), ptxTypeAttr(ptxTypeName) {}

  p << " " << name << "[";

template <typename Op>
  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
      frag.regs.push_back(op.getOperand(operandIdx));
      if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
        regTypes.push_back(op.getOperand(operandIdx).getType());

      regTypes.push_back(frag.regs[0].getType());

    std::optional<MMATypes> inferredType =
        MmaOp::inferOperandMMAType(regTypes.back(),
      ignoreAttrNames.push_back(frag.ptxTypeAttr);

  auto typeParser = [&]() {
    operandTypes.push_back(ty);

  if (operandTypes.size() != 3)
                     "expected exactly 3 types");

  if (!attrs.get("multiplicandAPtxType")) {
    if (auto inferredType =
            MmaOp::inferOperandMMAType(operandTypes[0], false)) {
      attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));

  if (!attrs.get("multiplicandBPtxType")) {
    if (auto inferredType =
            MmaOp::inferOperandMMAType(operandTypes[1], false)) {
      attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));

template <typename OpType>
                            ScaleVecSize scaleVecSize,
                            BlockScaleFormat blockScaleFormat,
                            MMABlockScaleKind kind) {
  auto &properties = result.getOrAddProperties<typename OpType::Properties>();
  properties.setShape(
  properties.setScaleVecSize(ScaleVecSizeAttr::get(ctx, scaleVecSize));
  properties.setBlockScaleFormat(
      BlockScaleFormatAttr::get(ctx, blockScaleFormat));
  properties.setKind(MMABlockScaleKindAttr::get(ctx, kind));

    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
    if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));

template <typename OpTy>
  return *MmaOp::inferOperandMMAType(
      cast<LLVM::LLVMStructType>(op.getRes().getType()).getBody()[0],
  std::array<MMAOperandFragment, 3> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
      mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};

  for (const auto &frag : frags)

                   {getScaleAData(), getByteIdA(), getThreadIdA()});
                   {getScaleBData(), getByteIdB(), getThreadIdB()});

                      frags[1].regs[0].getType(),
                      frags[2].regs[0].getType()},

ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
  struct LocalOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;

  std::array<LocalOperandFragment, 3> frags;

  for (const auto &[idx, frag] : llvm::enumerate(frags)) {
    frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
        .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),

        .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),

  result.addAttributes(namedAttributes);

  result.addTypes(resultTypes);
  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
void MmaBlockScaleOp::build(
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
    MMABlockScaleKind kind) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
                              blockScaleFormat, kind);

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
      {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
                          multiplicandPtxTypes);

  result.addTypes(resultType);
  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
                          static_cast<int32_t>(operandA.size()),
                          static_cast<int32_t>(operandB.size()),
                          static_cast<int32_t>(operandC.size()),

  auto curOp = cast<NVVM::MmaBlockScaleOp>(op);

  for (Value operand : curOp.getOperandA())
  for (Value operand : curOp.getOperandB())
  for (Value operand : curOp.getOperandC())

  args.push_back(mt.lookupValue(curOp.getScaleAData()));
  args.push_back(mt.lookupValue(curOp.getByteIdA()));
  args.push_back(mt.lookupValue(curOp.getThreadIdA()));
  args.push_back(mt.lookupValue(curOp.getScaleBData()));
  args.push_back(mt.lookupValue(curOp.getByteIdB()));
  args.push_back(mt.lookupValue(curOp.getThreadIdB()));

  unsigned intId = MmaBlockScaleOp::getIntrinsicID(
      curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
      *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
      curOp.getBlockScaleFormat(), curOp.getKind());

  return {intId, args};

LogicalResult MmaBlockScaleOp::verify() {
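  // Constraint summary enforced below: m16n8k64 requires e2m1 operands and is
  // valid for kind mxf4 (scale_vec_size x2 with ue8m0 scales) or mxf4nvf4
  // (x2 with ue8m0, or x4 with ue4m3); m16n8k32 is only valid for kind
  // mxf8f6f4 with scale_vec_size x1 and ue8m0 scales.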
  if (m == 16 && n == 8 && k == 64) {
    if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
        getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
          "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
    if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
      if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
            "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
      if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
            "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
    } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
      if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
            (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
            "attributes for mma.m16n8k64.mxf4nvf4");
  } else if (m == 16 && n == 8 && k == 32) {
    if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
          getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
          getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
      return emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
                         "attributes for mma.m16n8k32");
  std::array<MMAOperandFragment, 3> frags{
      MMAOperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      MMAOperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      MMAOperandFragment("C", "")};
      mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()};

  for (const auto &frag : frags)

                   {getScaleAData(), getByteIdA(), getThreadIdA()});
                   {getScaleBData(), getByteIdB(), getThreadIdB()});

                      frags[1].regs[0].getType(),
                      frags[2].regs[0].getType()},

ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
  struct LocalOperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;

  std::array<LocalOperandFragment, 3> frags;

  for (const auto &[idx, frag] : llvm::enumerate(frags)) {
    frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
        .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),

        .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),

        .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),

  result.addAttributes(namedAttributes);

  if (!result.attributes.get("orderedMetadata"))

  result.addTypes(resultTypes);
  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
void MmaSpBlockScaleOp::build(
    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
    MMABlockScaleKind kind) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
      builder, result, shape, scaleVecSize, blockScaleFormat, kind);

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);
  result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
                      threadIdA, scaleBData, byteIdB, threadIdB});
                          multiplicandPtxTypes);

  result.addTypes(resultType);
  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
                          static_cast<int32_t>(operandA.size()),
                          static_cast<int32_t>(operandB.size()),
                          static_cast<int32_t>(operandC.size()),

  auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);

  for (Value operand : curOp.getOperandA())
  for (Value operand : curOp.getOperandB())
  for (Value operand : curOp.getOperandC())

  args.push_back(mt.lookupValue(curOp.getSparseMetadata()));
  args.push_back(mt.lookupValue(curOp.getSparsitySelector()));

  args.push_back(mt.lookupValue(curOp.getScaleAData()));
  args.push_back(mt.lookupValue(curOp.getByteIdA()));
  args.push_back(mt.lookupValue(curOp.getThreadIdA()));
  args.push_back(mt.lookupValue(curOp.getScaleBData()));
  args.push_back(mt.lookupValue(curOp.getByteIdB()));
  args.push_back(mt.lookupValue(curOp.getThreadIdB()));

  unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
      curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
      *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
      curOp.getBlockScaleFormat(), curOp.getKind());

  return {intId, args};

LogicalResult MmaSpBlockScaleOp::verify() {
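  // Same scheme as MmaBlockScaleOp::verify, shifted for the sparse shapes:
  // the orderedMetadata attribute is mandatory, m16n8k128 requires e2m1
  // operands with the mxf4 / mxf4nvf4 scale rules, and m16n8k64 is reserved
  // for kind mxf8f6f4 with scale_vec_size x1 and ue8m0 scales.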
  if (!getOrderedMetadata()) {
    return emitOpError("'orderedMetadata' attribute is mandatory");

  if (m == 16 && n == 8 && k == 128) {
    if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
        getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
          "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
    if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
      if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
            "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
      if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
            "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
    } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
      if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
            (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
            "attributes for mma.m16n8k128.mxf4nvf4");
  } else if (m == 16 && n == 8 && k == 64) {
    if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
          getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
          getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
      return emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
                         "attributes for mma.m16n8k64");

LogicalResult ShflOp::verify() {
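  // When shfl returns a two-element struct {value, i1}, the
  // return_value_and_is_valid unit attribute must be present, the first
  // element must match the type of the shuffled value, and the second must be
  // the i1 validity predicate; otherwise the result type must equal the value
  // type directly.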
  auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());

  auto verifyTypeError = [&](Twine desc, Type expectedType,
                             Type actualType) -> LogicalResult {
    return emitOpError("expected " + desc + " to be of type ")
           << expectedType << " but got " << actualType << " instead";

  if (returnStructType) {
    if (!getReturnValueAndIsValid())
      return emitOpError("\"return_value_and_is_valid\" attribute must be "
                         "specified when the return type is a struct type");

    if (returnStructType.getBody().size() != 2)
      return emitOpError("expected return type to be a two-element struct");

    auto resultType = returnStruct[0];
    if (resultType != getVal().getType())
      return verifyTypeError("first element in the returned struct",
                             getVal().getType(), resultType);

    auto predicateType = returnStruct[1];
    if (!predicateType.isInteger(1))
      return verifyTypeError("second element in the returned struct",

  if (getReturnValueAndIsValid())
    return emitOpError("expected return type to be a two-element struct");

  return verifyTypeError("return type", getVal().getType(), getType());
                                                 NVVM::MMAFrag frag, int nRow,
  unsigned numberElements = 0;

  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
  } else if (type == NVVM::MMATypes::f64) {
    elementType = builder.getF64Type();
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    if (parallelSize == 16)
    else if (parallelSize == 8)
    else if (parallelSize == 32)
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();

  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);

static std::pair<mlir::Type, unsigned>
  if (frag == NVVM::MMAFrag::a) {
  } else if (frag == NVVM::MMAFrag::b) {

  assert(nRow && nCol);

LogicalResult NVVM::WMMALoadOp::verify() {
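  // The WMMA verifiers below accept pointers in the generic (0), Global, or
  // Shared address spaces, reject attribute combinations for which no
  // intrinsic exists (getIntrinsicID() returning 0), and check that the
  // result / data operand types match the fragment <type, count> pair
  // computed by the helpers above.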
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected source pointer in memory "

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";

  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
      return emitOpError("expected destination type to be f64");

  Type dstType = LLVM::LLVMStructType::getLiteral(
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected operands to be a source pointer in memory "

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
    return emitOpError() << "invalid attribute combination";

  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
    return emitOpError() << "expected data operands of type " << typeInfo.first;

LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
    return emitOpError() << "invalid attribute combination";

  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "

  Type dstType = LLVM::LLVMStructType::getLiteral(
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;

LogicalResult NVVM::LdMatrixOp::verify() {
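  // Shape/eltype/layout matrix enforced below (ldmatrix):
  //   8x8   -> num in {1, 2, 4}, element type b16
  //   8x16  -> num in {1, 2, 4}, row layout, b8x16.b4x16_p64 or b8x16.b6x16_p32
  //   16x16 -> num in {1, 2},    col layout, b8 or one of the b8x16 packed forms
  // The result is i32 for a single element and a struct of i32 otherwise
  // (16x16 counts num * 2 elements).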
  if (m == 8 && n == 8) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
    if (getEltType() != LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be b16 for 8x8 matrix");
  } else if (m == 8 && n == 16) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
    if (getLayout() != MMALayout::row) {
      return emitOpError("expected layout to be row for 8x16 matrix");
    if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 8x16 matrix");
  } else if (m == 16 && n == 16) {
    if (num != 1 && num != 2) {
      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
    if (getLayout() != MMALayout::col) {
      return emitOpError("expected layout to be col for 16x16 matrix");
    if (getEltType() != LdStMatrixEltType::B8 &&
        getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 16x16 matrix");
    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");

  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
  if (numElements == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (numElements == 2 || numElements == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
      return emitOpError("expected destination type is a structure of ")
             << numElements << " elements of type i32";

LogicalResult NVVM::StMatrixOp::verify() {
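  // stmatrix accepts 1, 2, or 4 source registers; 8x8 tiles must use the b16
  // element type, while 16x8 tiles must use b8 with a column-major layout.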
  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  if (m == 8 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be B16 for 8x8 matrix");
  } else if (m == 16 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B8) {
      return emitOpError("expected element type to be B8 for 16x8 matrix");
    if (getLayout() != NVVM::MMALayout::col) {
      return emitOpError("expected layout to be col for 16x8 matrix");
    return emitOpError("expected shape to be 8x8 or 16x8");
  if (typeA == NVVM::WGMMATypes::tf32)
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
  if (typeA == NVVM::WGMMATypes::b1)

                                         NVVM::WGMMATypes typeA,
                                         NVVM::WGMMATypes typeB) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");

                                      72,  80,  88,  96,  104, 112, 120, 128,
                                      136, 144, 152, 160, 168, 176, 184, 192,
                                      200, 208, 216, 224, 232, 240, 248, 256};
                                           80,  96,  112, 128, 144, 160,
                                           176, 192, 208, 224, 240, 256};
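  // wgmma N-size check: f16/bf16/tf32/e4m3/e5m2 inputs must use a value from
  // the first table (whose visible tail increases in steps of 8 up to 256),
  // while u8/s8/b1 inputs are restricted to the shorter second table.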
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");

LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
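  // Checks performed here: the result must be a struct of identical element
  // types, typeD must be f32/f16/s32 (s32 forbids negated scales), the
  // A/B/D type combination and the shape's K and N must be in the allowed
  // sets, non-f16/bf16 inputs may not request transposed layouts, the number
  // of struct elements must be N/2 (f32, s32) or N/4 (f16), and satfinite is
  // only meaningful with an s32 accumulator.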
2565 Value outValue = getResults();
2566 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.
getType());
2568 return emitOpError() <<
"expected results to be struct";
2569 int outputSize = stype.getBody().size();
2570 WGMMATypes typeD = getTypeD();
2571 WGMMATypes typeA = getTypeA();
2572 WGMMATypes typeB = getTypeB();
2574 for (
Type t : stype.getBody()) {
2575 if (t != stype.getBody().front())
2577 <<
"all elements in struct must be same type but there is " << t;
2580 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
2581 typeD != WGMMATypes::s32) {
2582 return emitOpError() <<
"does not support the given output type "
2583 << NVVM::stringifyWGMMATypes(typeD);
2585 if (typeD == WGMMATypes::s32 &&
2586 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
2587 return emitOpError() <<
"has s32 output, scaleA and scaleB cannot be neg";
2591 return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
2592 <<
" += " << NVVM::stringifyWGMMATypes(typeA) <<
" * "
2593 << NVVM::stringifyWGMMATypes(typeB)
2594 <<
", it is not supported.";
2604 return emitOpError() << "shape 'k' must be " << allowedK.value()
2605 << " for input type "
2606 << NVVM::stringifyWGMMATypes(typeA);
2611 << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
2612 << getShape().getN() << ", it is not supported.";
2619 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
2620 (getLayoutA() == mlir::NVVM::MMALayout::col ||
2621 getLayoutB() == mlir::NVVM::MMALayout::row)) {
2623 << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
2624 << " and layout_b = " << stringifyMMALayout(getLayoutB())
2625 << " for input types " << stringifyWGMMATypes(typeA) << " and "
2626 << stringifyWGMMATypes(typeB)
2627 << " requires transpose. However, this is only supported for: "
2628 << stringifyMMATypes(MMATypes::f16) << " and "
2629 << stringifyMMATypes(MMATypes::bf16);
2633 int expectedOutput = 0;
2634 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
2635 expectedOutput = getShape().getN() / 2;
2636 if (typeD == WGMMATypes::f16)
2637 expectedOutput = getShape().getN() / 4;
2638 if (outputSize != expectedOutput) {
2639 return emitOpError() << "results " << expectedOutput
2640 << ", however output struct has " << outputSize
2644 if (typeD != WGMMATypes::s32 &&
2645 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2646 NVVM::MMAIntOverflow::satfinite) {
2648 << " `satfinite` can be only used with s32 accumulator, however "
2649 "the current accumulator is "
2650 << NVVM::stringifyWGMMATypes(typeD);
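// Worked example (illustrative only): for m64n128k16 with an f32
// accumulator the result struct must have N/2 = 64 elements, while an f16
// accumulator needs N/4 = 32 packed elements. getPtx() below then emits a
// string of the form (sketch)
//   wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 $0, ..., $63, ...
// followed by the matrix descriptors and the scale/transpose immediates;
// the exact operand numbering is produced by the register loop below.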
2656std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
2659 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2661 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
2663 int expectedOutputRegisters = 0;
2664 if (getTypeD() == WGMMATypes::f16)
2665 expectedOutputRegisters = getShape().getN() / 4;
2667 expectedOutputRegisters = getShape().getN() / 2;
2670 llvm::raw_string_ostream ss(ptx);
2675 << ((expectedOutputRegisters * 2) + 2)
2677 "wgmma.mma_async.sync.aligned.m"
2678 << m << "n" << n << "k" << k << "." << outputTypeName << "."
2679 << stringifyWGMMATypes(getTypeA()) << "."
2680 << stringifyWGMMATypes(getTypeB());
2681 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
2682 NVVM::MMAIntOverflow::satfinite)
2686 for (; regCnt < expectedOutputRegisters; ++regCnt) {
2687 ss << "$" << regCnt;
2688 if (regCnt != expectedOutputRegisters - 1)
2694 regCnt = (regCnt * 2);
2695 ss << " $" << (regCnt) << ","
2696 << " $" << (regCnt + 1) << ","
2698 if (getTypeD() != WGMMATypes::s32) {
2699 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
2703 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
2710bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
2714 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
2721 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
2723 if (getTypeD() != WGMMATypes::s32) {
2724 asmValues.push_back(
2725 {makeConstantI32(rewriter,
2726 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2728 asmValues.push_back(
2729 {makeConstantI32(rewriter,
2730 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
2734 asmValues.push_back(
2735 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
2737 asmValues.push_back(
2738 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
2744LogicalResult NVVM::FenceSyncRestrictOp::verify() {
2745 if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
2746 getOrder() != NVVM::MemOrderKind::RELEASE)
2747 return emitOpError("only acquire and release semantics are supported");
2751LogicalResult NVVM::FenceProxyOp::verify() {
2752 if (getKind() == NVVM::ProxyKind::TENSORMAP)
2753 return emitOpError() << "tensormap proxy is not a supported proxy kind";
2754 if (getKind() == NVVM::ProxyKind::GENERIC)
2755 return emitOpError() << "generic proxy is not a supported proxy kind";
2756 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
2757 return emitOpError() << "async_shared fence requires space attribute";
2759 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
2760 return emitOpError() << "only async_shared fence can have space attribute";
2765LogicalResult NVVM::FenceProxyAcquireOp::verify() {
2766 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2767 return emitOpError("uni-directional proxies only support generic for "
2768 "from_proxy attribute");
2770 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2771 return emitOpError("uni-directional proxies only support tensormap "
2772 "for to_proxy attribute");
2776LogicalResult NVVM::FenceProxyReleaseOp::verify() {
2777 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2778 return emitOpError("uni-directional proxies only support generic for "
2779 "from_proxy attribute");
2781 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
2782 return emitOpError("uni-directional proxies only support tensormap "
2783 "for to_proxy attribute");
2787LogicalResult NVVM::FenceProxySyncRestrictOp::verify() {
2788 if (getOrder() != NVVM::MemOrderKind::ACQUIRE &&
2789 getOrder() != NVVM::MemOrderKind::RELEASE)
2790 return emitOpError("only acquire and release semantics are supported");
2792 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
2793 return emitOpError("only generic is supported for from_proxy attribute");
2795 if (getToProxy() != NVVM::ProxyKind::async)
2796 return emitOpError("only async is supported for to_proxy attribute");
2800LogicalResult NVVM::SetMaxRegisterOp::verify() {
2801 if (getRegCount() % 8)
2802 return emitOpError("new register size must be multiple of 8");
2803 if (getRegCount() < 24 || getRegCount() > 256)
2804 return emitOpError("new register size must be between 24 and 256");
2808LogicalResult NVVM::BarrierOp::verify() {
2809 if (getNumberOfThreads() && !getBarrierId())
2811 "barrier id is missing, it should be set between 0 and 15");
2813 if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
2814 return emitOpError("reductions are only available when id is 0");
2818 return emitOpError("reduction predicate and reduction operation must be "
2819 "specified together");
2824LogicalResult NVVM::Tcgen05CpOp::verify() {
2825 auto mc = getMulticast();
2827 using SH = Tcgen05CpShape;
2828 using MC = Tcgen05CpMulticast;
2830 case SH::SHAPE_128x256b:
2831 case SH::SHAPE_128x128b:
2832 case SH::SHAPE_4x256b:
2834 return emitError("Invalid multicast type for tcgen05.cp Op");
2836 case SH::SHAPE_64x128b:
2837 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
2838 return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
2839 "warpx2_02_13 for tcgen05.cp Op");
2841 case SH::SHAPE_32x128b:
2842 if (mc != MC::WARPX4)
2844 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
2850LogicalResult NVVM::MatchSyncOp::verify() {
2851 if (getKind() == NVVM::MatchSyncKind::all) {
2852 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
2853 if (!type || type.getBody().size() != 2 ||
2854 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
2855 return emitOpError("match.sync 'all' returns a two element struct with "
2856 "first element as i32 and second element as i1");
2859 if (!getType().isInteger(32)) {
2860 return emitOpError("match.sync 'any' returns an i32");
2866LogicalResult NVVM::VoteSyncOp::verify() {
2867 if (getKind() == NVVM::VoteSyncKind::ballot) {
2868 if (!getType().isInteger(32)) {
2869 return emitOpError("vote.sync 'ballot' returns an i32");
2872 if (!getType().isInteger(1)) {
2873 return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
2879LogicalResult NVVM::PrefetchOp::verify() {
2880 using MemSpace = NVVM::NVVMMemorySpace;
2881 using CacheLevel = NVVM::PrefetchCacheLevel;
2883 unsigned addressSpace =
2884 llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
2885 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
2886 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
2888 if (getTensormap() && cacheLevel)
2889 return emitOpError("cannot specify both tensormap and cache level");
2891 if (getTensormap()) {
2892 if (addressSpace != MemSpace::Generic &&
2893 addressSpace != MemSpace::Constant) {
2895 "prefetch tensormap requires a generic or constant pointer");
2898 if (evictPriority) {
2900 "prefetch tensormap does not support eviction priority");
2903 if (getInParamSpace() && addressSpace != MemSpace::Generic) {
2905 "in_param_space can only be specified for a generic pointer");
2908 }
else if (cacheLevel) {
2909 if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
2910 addressSpace != MemSpace::Local) {
2911 return emitOpError("prefetch to cache level requires a generic, global, "
2912 "or local pointer");
2916 if (*cacheLevel != CacheLevel::L1) {
2918 "unsupported cache level, the only supported uniform "
2919 "cache level is L1");
2922 if (addressSpace != MemSpace::Generic) {
2924 "prefetch to uniform cache requires a generic pointer");
2928 if (evictPriority) {
2929 if (*cacheLevel != CacheLevel::L2)
2931 "cache eviction priority supported only for cache level L2");
2933 if (addressSpace != MemSpace::Global)
2934 return emitOpError("cache eviction priority requires a global pointer");
2936 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
2937 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
2939 "unsupported cache eviction priority, only evict_last and "
2940 "evict_normal are supported");
2944 return emitOpError("predicate supported only on prefetch tensormap");
2948 "requires specification of either cache level or tensormap");
2954LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
2955 switch (getQueryType()) {
2956 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
2958 return emitOpError("is_canceled query type returns an i1");
2960 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
2961 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
2962 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
2963 if (!getType().isInteger(32)) {
2964 return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
2965 "get_first_cta_id_z query types return an i32");
2972LogicalResult NVVM::ReduxOp::verify() {
2975 if (!reduxType.isF32()) {
2977 return emitOpError("abs attribute is supported only for f32 type");
2979 return emitOpError("nan attribute is supported only for f32 type");
2982 NVVM::ReduxKind kind = getKind();
2984 case NVVM::ReduxKind::ADD:
2985 case NVVM::ReduxKind::AND:
2986 case NVVM::ReduxKind::OR:
2987 case NVVM::ReduxKind::XOR:
2988 case NVVM::ReduxKind::MAX:
2989 case NVVM::ReduxKind::MIN:
2990 case NVVM::ReduxKind::UMAX:
2991 case NVVM::ReduxKind::UMIN:
2994 << stringifyEnum(kind) << "' redux kind unsupported with "
2995 << reduxType << " type. Only supported type is 'i32'.";
2997 case NVVM::ReduxKind::FMIN:
2998 case NVVM::ReduxKind::FMAX:
2999 if (!reduxType.isF32())
3001 << stringifyEnum(kind) << "' redux kind unsupported with "
3002 << reduxType << " type. Only supported type is 'f32'.";
3015 unsigned sizeInBits,
3017 field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
3019 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
3020 if (mask != 0xffffffffu)
3021 field = builder.CreateAnd(field, builder.getInt32(mask));
3023 field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
3024 field = builder.CreateShl(field, start);
3026 return builder.CreateOr(result, field);
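// Worked example (a sketch of what the helper above computes): packing the
// value 0x3 as a 14-bit field starting at bit 32 of a zeroed descriptor
// yields result | ((0x3 & 0x3FFF) << 32) == 0x0000000300000000.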
3029void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
3031 llvm::IRBuilderBase &builder) {
3032 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
3033 llvm::Value *smemDesc = builder.getInt64(0);
3038 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
3040 builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
3046 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
3050 mt.mapValue(thisOp.getRes()) = smemDesc;
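// Field sketch for the 64-bit shared-memory descriptor built above (only
// the fields visible in this lowering are listed): a 14-bit
// leading-dimension byte offset at bit 16, a 14-bit stride-dimension byte
// offset at bit 32, and a 1-bit leading-dimension mode at bit 52; the
// remaining fields (base address, swizzle mode, ...) are packed by the
// elided calls.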
3057std::string NVVM::MBarrierInitOp::getPtx() {
3059 return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
3060 : std::string("mbarrier.init.b64 [%0], %1;");
3063std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
3066 ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
3067 : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
3070std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
3072 llvm::StringRef space = isShared ? ".shared" : "";
3074 return llvm::formatv("{\n\t"
3075 ".reg .pred P1; \n\t"
3077 "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
3078 "@P1 bra.uni DONE; \n\t"
3079 "bra.uni LAB_WAIT; \n\t"
3091 auto thisOp = cast<NVVM::BarrierOp>(op);
3092 llvm::Value *barrierId = thisOp.getBarrierId()
3094 : builder.getInt32(0);
3095 llvm::Intrinsic::ID id;
3097 if (thisOp.getNumberOfThreads()) {
3098 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
3099 args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
3100 } else if (thisOp.getReductionOp()) {
3101 switch (*thisOp.getReductionOp()) {
3102 case NVVM::BarrierReduction::AND:
3103 id = llvm::Intrinsic::nvvm_barrier_cta_red_and_aligned_all;
3105 case NVVM::BarrierReduction::OR:
3106 id = llvm::Intrinsic::nvvm_barrier_cta_red_or_aligned_all;
3108 case NVVM::BarrierReduction::POPC:
3109 id = llvm::Intrinsic::nvvm_barrier_cta_red_popc_aligned_all;
3112 args.push_back(builder.CreateICmpNE(
3113 mt.lookupValue(thisOp.getReductionPredicate()), builder.getInt32(0)));
3115 id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
3118 return {id, std::move(args)};
3123 llvm::IRBuilderBase &builder) {
3124 auto thisOp = cast<NVVM::PMEventOp>(op);
3128 llvm::Value *maskVal;
3129 if (auto eventAttr = thisOp.getEventIdAttr()) {
3130 uint16_t mask = static_cast<uint16_t>(1u << eventAttr.getInt());
3131 maskVal = llvm::ConstantInt::get(i16Ty, mask);
3134 llvm::ConstantInt::get(i16Ty, thisOp.getMaskedEventIdAttr().getValue());
3137 return {llvm::Intrinsic::nvvm_pm_event_mask, {maskVal}};
3142 auto thisOp = cast<NVVM::MBarrierInitOp>(op);
3144 llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
3145 : llvm::Intrinsic::nvvm_mbarrier_init;
3150 args.push_back(mt.lookupValue(thisOp.getCount()));
3152 return {id, std::move(args)};
3157 auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
3159 llvm::Intrinsic::ID id = isShared
3160 ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
3161 : llvm::Intrinsic::nvvm_mbarrier_inval;
3168 auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op);
3171 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3174 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
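// Index sketch: bit 1 encodes the scope and bit 0 the state space, so
// (scope=cta, space=cta) -> 0, (cta, cluster) -> 1, (cluster, cta) -> 2,
// and (cluster, cluster) -> 3, matching the order of the IDs table below.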
3176 static constexpr llvm::Intrinsic::ID IDs[] = {
3177 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta,
3178 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster,
3179 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta,
3180 llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster};
3185 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3187 return {IDs[index], std::move(args)};
3192 auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op);
3195 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3198 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3200 static constexpr llvm::Intrinsic::ID IDs[] = {
3201 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta,
3202 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster,
3203 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta,
3204 llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster};
3209 args.push_back(mt.lookupValue(thisOp.getTxcount()));
3211 return {IDs[index], std::move(args)};
3216 auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
3219 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3222 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3224 static constexpr llvm::Intrinsic::ID IDs[] = {
3225 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta,
3226 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster,
3227 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta,
3228 llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster};
3229 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3230 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta,
3231 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster,
3232 llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta,
3234 nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster};
3235 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3239 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3246 bool hasCount = static_cast<bool>(thisOp.getCount());
3248 (id == llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta))
3249 return {llvm::Intrinsic::nvvm_mbarrier_arrive_shared, {mbar}};
3253 llvm::Value *count =
3255 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3256 return {id, {mbar, count}};
3261 auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op);
3264 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3267 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3269 static constexpr llvm::Intrinsic::ID IDs[] = {
3270 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta,
3271 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster,
3272 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta,
3273 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster};
3274 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3275 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta,
3277 nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster,
3279 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta,
3281 nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster};
3282 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3286 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3292 bool hasCount = static_cast<bool>(thisOp.getCount());
3293 llvm::Value *count =
3295 : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1);
3297 return {id, {mbar, count}};
3300bool MBarrierArriveExpectTxOp::getAsmValues(
3307 for (auto val : getOperands())
3315 auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op);
3318 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3321 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3324 static constexpr llvm::Intrinsic::ID IDs[] = {
3325 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta,
3326 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster,
3327 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta,
3328 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster};
3329 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3330 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta,
3331 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster,
3332 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta,
3333 llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster};
3335 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3338 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3339 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3344 return {id, {mbar, txcount}};
3349 auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op);
3352 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3355 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0);
3358 static constexpr llvm::Intrinsic::ID IDs[] = {
3359 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta,
3360 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster,
3361 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta,
3362 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster};
3363 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3364 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta,
3365 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster,
3366 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta,
3367 llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster};
3369 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3372 llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount());
3373 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3378 return {id, {mbar, txcount}};
3383 auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
3385 llvm::Intrinsic::ID id =
3386 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
3387 : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
3391 args.push_back(mt.lookupValue(thisOp.getCount()));
3393 return {id, std::move(args)};
3398 auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op);
3400 llvm::Intrinsic::ID id =
3401 isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared
3402 : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete;
3406 args.push_back(mt.lookupValue(thisOp.getCount()));
3408 return {id, std::move(args)};
3413 auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
3414 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3415 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3418 size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
3421 static constexpr llvm::Intrinsic::ID IDs[] = {
3422 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
3423 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
3424 llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
3425 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
3426 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3427 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
3428 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
3429 llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
3430 llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
3432 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3435 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3436 llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
3441 return {id, {mbar, input}};
3446 auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
3447 bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
3448 bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
3449 bool hasTicks = static_cast<bool>(thisOp.getTicks());
3453 size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
3454 (isPhaseParity ? 1 : 0);
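// Index sketch: bit 2 selects the "_tl" (time-limit) variants used when a
// ticks operand is present, bit 1 the cluster scope, and bit 0 the
// phase-parity form; e.g. a parity wait at CTA scope with a tick limit
// maps to index 0b101 = 5 in the tables below.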
3457 static constexpr llvm::Intrinsic::ID IDs[] = {
3458 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
3459 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
3460 llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
3461 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
3462 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
3463 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
3464 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
3465 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
3466 static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
3467 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
3468 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
3469 llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
3470 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
3471 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
3472 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
3473 llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
3474 llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
3476 auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
3479 llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
3486 args.push_back(mbar);
3487 args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
3489 args.push_back(mt.lookupValue(thisOp.getTicks()));
3491 return {id, std::move(args)};
3496 auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
3499 llvm::Intrinsic::ID id;
3500 if (thisOp.getNoinc()) {
3501 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
3502 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
3504 id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
3505 : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
3511#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
3512 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
3514#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
3515 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
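// Expansion sketch: GET_CP_ASYNC_ID(ca, 16, /*has_cpsize=*/true) selects
// llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16_s, while the same
// call with has_cpsize == false drops the trailing _s suffix.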
3520 llvm::Intrinsic::ID id;
3522 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
3523 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
3524 switch (cpAsyncOp.getSize()) {
3532 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
3537 llvm_unreachable("Invalid copy size in CpAsyncOp.");
3541 args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
3542 args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
3544 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
3551 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
3553 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
3556 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3560 const bool hasCacheHint = static_cast<bool>(cacheHint);
3561 llvm::Value *i64Unused =
3562 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3563 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3564 args.push_back(builder.getInt1(hasCacheHint));
3566 return {id, std::move(args)};
3571 auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
3575 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3577 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3581 mlir::Value multicastMask = thisOp.getMulticastMask();
3582 const bool hasMulticastMask = static_cast<bool>(multicastMask);
3585 llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
3586 args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
3592 const bool hasCacheHint = static_cast<bool>(cacheHint);
3593 llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
3594 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3598 args.push_back(builder.getInt1(hasMulticastMask));
3599 args.push_back(builder.getInt1(hasCacheHint));
3601 llvm::Intrinsic::ID id =
3603 ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
3604 : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
3606 return {id, std::move(args)};
3611 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
3613 llvm::Intrinsic::ID id =
3614 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
3617 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3618 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3622 const bool hasCacheHint = static_cast<bool>(cacheHint);
3623 llvm::Value *i64Unused =
3624 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3625 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3626 args.push_back(builder.getInt1(hasCacheHint));
3629 if (mlir::Value byteMask = thisOp.getByteMask()) {
3631 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
3634 return {id, std::move(args)};
3637bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
3644 for (auto val : getOperands())
3651CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3653 auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
3654 const bool isCTAOnly = thisOp.getIsCTAOnly();
3658 args.push_back(mt.lookupValue(thisOp.getDstMem()));
3660 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3670 const bool hasMC = static_cast<bool>(mcMask);
3671 llvm::Value *i16Zero =
3672 llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
3676 const bool hasCacheHint = static_cast<bool>(cacheHint);
3677 llvm::Value *i64Zero =
3678 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3684 thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
3686 llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
3690 args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
3691 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
3692 args.push_back(builder.getInt1(hasMC));
3693 args.push_back(builder.getInt1(hasCacheHint));
3697 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
3698 args.push_back(builder.getInt1(hasCacheHint));
3701 constexpr size_t numDims = 5;
3702 constexpr size_t numModes = 5;
3703 using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
3704 using TableTy = std::array<rowTy, numModes>;
3705 static constexpr TableTy IDTable{
3706 {{
notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
3707 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
3708 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
3709 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
3710 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
3712 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
3713 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
3714 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
3716 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
3717 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
3718 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
3720 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
3721 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
3722 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
3724 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
3726 static constexpr TableTy IDTableCTA{
3728 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
3729 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
3730 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
3731 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
3732 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
3734 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
3735 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
3736 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
3738 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
3739 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
3740 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
3742 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
3743 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
3744 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
3746 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
3749 (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
3750 (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
3751 "TMALoadModes must match number of rows in IDTable and IDTableCTA");
3752 size_t mode = static_cast<size_t>(thisOp.getMode());
3753 size_t dim = thisOp.getCoordinates().size();
3754 auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
3756 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
3758 return {id, std::move(args)};
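// Lookup sketch: a 3-D tile-mode load (mode == TILE, dim == 3) resolves to
// nvvm_cp_async_bulk_tensor_g2s_tile_3d, or to the corresponding _cta_
// entry when is_cta_only is set; the notIntrinsic slots mark
// mode/dimension combinations that the verifier rejects.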
3763 auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
3767 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3769 for (auto v : thisOp.getCoordinates())
3771 for (auto v : thisOp.getIm2colOffsets())
3775 const bool hasCacheHint = static_cast<bool>(cacheHint);
3776 llvm::Value *i64Unused =
3777 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3778 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3779 args.push_back(builder.getInt1(hasCacheHint));
3781 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3782 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3783 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
3784 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
3785 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
3786 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
3787 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
3789 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
3790 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
3791 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
3793 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
3794 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
3795 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
3797 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
3798 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
3799 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
3800 {NI, NI, NI, NI, NI,
3801 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
3803 static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
3804 "TMALoadModes must match number of rows in IDTable");
3805 size_t mode = static_cast<size_t>(thisOp.getMode());
3806 size_t dim = thisOp.getCoordinates().size();
3807 llvm::Intrinsic::ID id = IDTable[mode][dim];
3808 if (id == llvm::Intrinsic::not_intrinsic)
3809 llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
3811 return {id, std::move(args)};
3815CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
3817 auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
3821 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3822 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3824 for (auto v : thisOp.getCoordinates())
3828 const bool hasCacheHint = static_cast<bool>(cacheHint);
3829 llvm::Value *i64Unused =
3830 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
3831 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
3832 args.push_back(builder.getInt1(hasCacheHint));
3834 const unsigned NI = llvm::Intrinsic::not_intrinsic;
3835 static constexpr llvm::Intrinsic::ID IDTable[][6] = {
3836 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
3837 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
3838 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
3839 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
3840 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
3841 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
3842 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
3843 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
3844 {NI, NI, NI, NI, NI,
3845 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
3847 static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
3848 "TMAStoreModes must match number of rows in IDTable");
3849 size_t mode = static_cast<size_t>(thisOp.getMode());
3850 size_t dim = thisOp.getCoordinates().size();
3851 llvm::Intrinsic::ID id = IDTable[mode][dim];
3852 if (id == llvm::Intrinsic::not_intrinsic)
3854 "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");
3856 return {id, std::move(args)};
3861 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
3869 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
3870 args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
3872 for (Value v : thisOp.getCoordinates())
3876 const bool hasCacheHint = static_cast<bool>(cacheHint);
3877 llvm::Value *i64ZeroValue =
3878 llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
3879 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
3880 args.push_back(builder.getInt1(hasCacheHint));
3882 const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
3884 constexpr unsigned numRedKinds = 8;
3885 constexpr unsigned numLayouts = 2;
3886 constexpr unsigned maxDim = 5;
3887 using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
3888 using layoutTable = std::array<row, numLayouts>;
3889 using fullTable = std::array<layoutTable, numRedKinds>;
3890 static constexpr fullTable IDTable{
3893 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
3894 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
3895 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
3896 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
3897 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
3899 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
3900 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
3901 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
3904 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
3905 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
3906 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
3907 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
3908 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
3910 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
3911 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
3912 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
3915 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
3916 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
3917 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
3918 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
3919 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
3921 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
3922 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
3923 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
3926 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
3927 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
3928 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
3929 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
3930 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
3932 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
3933 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
3934 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
3937 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
3938 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
3939 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
3940 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
3941 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
3943 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
3944 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
3945 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
3948 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
3949 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
3950 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
3951 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
3952 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
3954 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
3955 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
3956 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
3959 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
3960 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
3961 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
3962 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
3963 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
3965 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
3966 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
3967 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
3970 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
3971 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
3972 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
3973 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
3974 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
3976 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
3977 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
3979 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
3981 static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
3982 "TMAReduxKinds must match number of rows in IDTable");
3984 size_t redKind = static_cast<size_t>(thisOp.getRedKind());
3985 size_t mode = static_cast<size_t>(thisOp.getMode());
3986 size_t dim = thisOp.getCoordinates().size();
3988 assert(redKind < IDTable.size() &&
3989 "Invalid redKind for CpAsyncBulkTensorReduceOp");
3990 assert(mode < IDTable[redKind].size() &&
3991 "Invalid mode for CpAsyncBulkTensorReduceOp");
3992 assert(dim < IDTable[redKind][mode].size() &&
3993 "Invalid dim for CpAsyncBulkTensorReduceOp");
3995 llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
3998 "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
4000 return {intrinsicID, std::move(args)};
4005#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4006 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
4007 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
4009#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
4010 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
4011 : CVT_F2TF32_ID_IMPL(rnd, relu, )
4014ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4015 NVVM::SaturationMode sat, bool hasRelu) {
4016 using RndMode = NVVM::FPRoundingMode;
4017 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4026 llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
4031ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
4033 llvm::IRBuilderBase &builder) {
4038 bool hasRelu = op.getRelu();
4040 llvm::Intrinsic::ID intId =
4041 hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
4042 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
4044 return {intId, std::move(args)};
4047#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
4048 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
4049 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
4051llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
4054 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4057 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4061 llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
4062 return llvm::Intrinsic::not_intrinsic;
4066#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
4067 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
4068 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
4070#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
4071 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
4072 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
4075ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
4076 NVVM::SaturationMode sat, bool hasRelu) {
4077 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4078 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
4079 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
4082 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4085 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4088 .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
4089 if (hasRoundingModeRZ)
4091 else if (hasRoundingModeRP)
4094 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4097 llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
4098 return llvm::Intrinsic::not_intrinsic;
4102#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
4103 has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
4104 : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
4106llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
4109 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4112 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4116 llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
4117 return llvm::Intrinsic::not_intrinsic;
4121#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
4122 has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
4123 : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
4126ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
4127 NVVM::SaturationMode sat) {
4128 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
4130 case NVVM::FPRoundingMode::RZ:
4132 case NVVM::FPRoundingMode::RP:
4135 llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
4141 auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
4143 bool hasRelu = curOp.getRelu();
4145 llvm::Intrinsic::ID intId =
4147 .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
4148 return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
4149 : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
4151 .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
4152 return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
4153 : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
4156 llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
4157 return llvm::Intrinsic::not_intrinsic;
4160 llvm::Value *packedI16 =
4161 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4162 llvm::Type::getInt16Ty(builder.getContext()));
4164 return {intId, {packedI16}};
4169 auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
4171 llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
4172 llvm::Value *packedI16 =
4173 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4174 llvm::Type::getInt16Ty(builder.getContext()));
4176 return {intId, {packedI16}};
4181 auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
4183 bool hasRelu = curOp.getRelu();
4185 llvm::Intrinsic::ID intId =
4187 .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
4188 return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
4189 : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
4191 .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
4192 return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
4193 : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
4196 llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
4197 return llvm::Intrinsic::not_intrinsic;
4200 llvm::Value *packedI16 =
4201 builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
4202 llvm::Type::getInt16Ty(builder.getContext()));
4204 return {intId, {packedI16}};
4209 auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
4211 bool hasRelu = curOp.getRelu();
4213 llvm::Intrinsic::ID intId =
4215 .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
4216 return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
4217 : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
4220 llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
4221 return llvm::Intrinsic::not_intrinsic;
4224 llvm::Value *extendedI16 =
4225 builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
4226 llvm::Type::getInt16Ty(builder.getContext()));
4228 return {intId, {extendedI16}};
4232Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
4235 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
4236 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4238 bool isShared = as == NVVMMemorySpace::Shared;
4239 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4241 llvm::Intrinsic::ID id;
4243 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
4244 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
4246 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
4247 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
4257llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
4260 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
4261 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
4262 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
4263 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
4272#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
4273 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
4274 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
4276#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
4277 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
4278 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
4281Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
4284 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
4285 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
4287 bool isShared = as == NVVMMemorySpace::Shared;
4288 bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
4289 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
4291 llvm::Intrinsic::ID id =
4298 args.push_back(mt.lookupValue(curOp.getMulticastMask()));
4303#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
4304 llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
4306#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
4307 is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
4308 : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
4310#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
4312 if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
4313 return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
4314 if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
4315 return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
4316 return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
4320ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
4322 llvm::IRBuilderBase &builder) {
4323 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4324 llvm::Intrinsic::nvvm_ff2f16x2_rn,
4325 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
4326 llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
4327 llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
4329 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4330 llvm::Intrinsic::nvvm_ff2f16x2_rz,
4331 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
4332 llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
4333 llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
4335 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4336 llvm::Intrinsic::nvvm_ff2f16x2_rs,
4337 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
4338 llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
4339 llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
4342 unsigned hasRelu = op.getRelu() ? 1 : 0;
4343 unsigned hasSatFinite =
4344 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4347 unsigned idx = (hasSatFinite << 1) | hasRelu;
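// Index sketch: bit 1 selects the _satfinite variants and bit 0 the _relu
// variants, so relu + satfinite picks entry 3 (the _relu_satfinite
// intrinsic) of the table for the chosen rounding mode.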
4352 if (op.getRandomBits())
4353 args.push_back(mt.lookupValue(op.getRandomBits()));
4355 switch (op.getRnd()) {
4356 case FPRoundingMode::RN:
4357 return {rndRNIds[idx], std::move(args)};
4358 case FPRoundingMode::RZ:
4359 return {rndRZIds[idx], std::move(args)};
4360 case FPRoundingMode::RS:
4361 return {rndRSIds[idx], std::move(args)};
4363 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
4368ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
4370 llvm::IRBuilderBase &builder) {
4371 static constexpr llvm::Intrinsic::ID rndRNIds[] = {
4372 llvm::Intrinsic::nvvm_ff2bf16x2_rn,
4373 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
4374 llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
4375 llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
4377 static constexpr llvm::Intrinsic::ID rndRZIds[] = {
4378 llvm::Intrinsic::nvvm_ff2bf16x2_rz,
4379 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
4380 llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
4381 llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
4383 static constexpr llvm::Intrinsic::ID rndRSIds[] = {
4384 llvm::Intrinsic::nvvm_ff2bf16x2_rs,
4385 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
4386 llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
4387 llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
4390 unsigned hasRelu = op.getRelu() ? 1 : 0;
4391 unsigned hasSatFinite =
4392 (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
4395 unsigned idx = (hasSatFinite << 1) | hasRelu;
4400 if (op.getRandomBits())
4401 args.push_back(mt.lookupValue(op.getRandomBits()));
4403 switch (op.getRnd()) {
4404 case FPRoundingMode::RN:
4405 return {rndRNIds[idx], std::move(args)};
4406 case FPRoundingMode::RZ:
4407 return {rndRZIds[idx], std::move(args)};
4408 case FPRoundingMode::RS:
4409 return {rndRSIds[idx], std::move(args)};
4411 llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
4415llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
4417 bool hasRelu = getRelu();
4420 .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
4421 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
4422 : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
4424 .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
4425 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
4426 : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
4429 llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
4430 return llvm::Intrinsic::not_intrinsic;
4434llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
4436 bool hasRelu = getRelu();
4439 .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
4440 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
4441 : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
4443 .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
4444 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
4445 : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
4448 llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
4449 return llvm::Intrinsic::not_intrinsic;
4453llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
4455 bool hasRelu = getRelu();
4458 .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
4459 return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
4460 : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
4463 llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
4464 return llvm::Intrinsic::not_intrinsic;
4468llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
4469 auto curOp = cast<NVVM::Tcgen05CpOp>(op);
4470 bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
4471 auto srcFmt = curOp.getSrcFormat();
4472 auto mc = curOp.getMulticast();
4474 switch (curOp.getShape()) {
4475 case Tcgen05CpShape::SHAPE_128x256b:
4477 case Tcgen05CpShape::SHAPE_128x128b:
4479 case Tcgen05CpShape::SHAPE_4x256b:
4481 case Tcgen05CpShape::SHAPE_32x128b:
4483 case Tcgen05CpShape::SHAPE_64x128b:
4484 return (mc == Tcgen05CpMulticast::WARPX2_01_23)
4488 llvm_unreachable("Invalid shape in tcgen05 cp Op");
4495 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
4497 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
4502LogicalResult Tcgen05LdOp::verify() {
4504 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4507 if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
4508 result = emitError("offset argument is only supported for shape 16x32bx2");
4510 auto resTy = getRes().getType();
4511 unsigned resLen = isa<VectorType>(resTy)
4512 ? llvm::cast<VectorType>(resTy).getNumElements()
4515 result = emitError(llvm::formatv("invalid result type length {0} for shape "
4516 "{1} in tcgen05.ld Op",
4517 resLen, stringifyEnum(getShape())));
4522LogicalResult Tcgen05StOp::verify() {
4524 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
4527 auto valTy = getVal().getType();
4528 unsigned valLen = isa<VectorType>(valTy)
4529 ? llvm::cast<VectorType>(valTy).getNumElements()
4532 result = emitError(llvm::formatv("invalid input length {0} for shape "
4533 "{1} in tcgen05.st Op",
4534 valLen, stringifyEnum(getShape())));
4544 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
4545 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
4546 rangeAttr.getLower(), rangeAttr.getUpper()});
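// verifyConstantRangeAttr: checks that the optional range attribute satisfies
// the llvm::ConstantRange constructor requirements. Equal lower and upper
// bounds are only legal when they denote the full or empty set, i.e. both are
// the minimum or maximum value of the bit width.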
4554 std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
4558 const llvm::APInt &lower = rangeAttr->getLower();
4559 const llvm::APInt &upper = rangeAttr->getUpper();
4562 if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
4563 unsigned bitWidth = lower.getBitWidth();
4564 llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
4565 llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
4567 "invalid range attribute: Lower == Upper, but they aren't min (")
4568 << llvm::toString(minVal, 10,
false) <<
") or max ("
4569 << llvm::toString(maxVal, 10,
false)
4570 <<
") value! This is an invalid constant range.";
4577 llvm::IRBuilderBase &builder) {
4578 return builder.CreateBitCast(arg,
4579 llvm::Type::getInt32Ty(builder.getContext()));
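// For the dp4a/dp2a lowerings below, the signedness of the two operands is
// packed into a 2-bit index ((isASigned << 1) | isBSigned) that selects one of
// four intrinsics ordered u_u, u_s, s_u, s_s. For example, a signed A with an
// unsigned B yields index 0b10 = 2, i.e. the _s_u variant.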
4584 auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
4591 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4592 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4593 unsigned type = (isASigned << 1) | isBSigned;
4594 const llvm::Intrinsic::ID ids[] = {
4595 llvm::Intrinsic::nvvm_idp4a_u_u,
4596 llvm::Intrinsic::nvvm_idp4a_u_s,
4597 llvm::Intrinsic::nvvm_idp4a_s_u,
4598 llvm::Intrinsic::nvvm_idp4a_s_s,
4600 return {ids[type], args};
4605 auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);
4610 args.push_back(builder.getInt1(curOp.getBHi()));
4613 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
4614 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
4615 unsigned type = (isASigned << 1) | isBSigned;
4616 const llvm::Intrinsic::ID ids[] = {
4617 llvm::Intrinsic::nvvm_idp2a_u_u,
4618 llvm::Intrinsic::nvvm_idp2a_u_s,
4619 llvm::Intrinsic::nvvm_idp2a_s_u,
4620 llvm::Intrinsic::nvvm_idp2a_s_s,
4622 return {ids[type], args};
4626 llvm::IRBuilderBase &builder) {
4627 return builder.CreateAddrSpaceCast(
4629 llvm::PointerType::get(builder.getContext(),
4630 llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
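// Lowers nvvm.prefetch to the matching intrinsic. Precedence of the checks:
// tensormap prefetches are handled first, then the uniform L1 form, then L2
// prefetches with an explicit eviction priority, and finally a plain L1/L2
// prefetch chosen by the pointer's address space (generic, global, or local).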
4634PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
4636 llvm::IRBuilderBase &builder) {
4637 using MemSpace = NVVM::NVVMMemorySpace;
4638 using CacheLevel = NVVM::PrefetchCacheLevel;
4640 std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
4641 std::optional<NVVM::CacheEvictionPriority> evictPriority =
4642 op.getEvictPriority();
4643 unsigned addressSpace =
4644 llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
4652 if (op.getTensormap())
4653 return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
4655 assert(cacheLevel && "expected cache level for non-tensormap prefetch");
4657 if (op.getUniform() && *cacheLevel == CacheLevel::L1)
4658 return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
4660 if (evictPriority && *cacheLevel == CacheLevel::L2) {
4661 switch (*evictPriority) {
4662 case NVVM::CacheEvictionPriority::EvictLast:
4663 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
4664 case NVVM::CacheEvictionPriority::EvictNormal:
4665 return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
4667 llvm_unreachable("Invalid cache eviction priority");
4671 switch (static_cast<MemSpace>(addressSpace)) {
4672 case MemSpace::Generic:
4673 return *cacheLevel == CacheLevel::L1
4675 : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
4676 case MemSpace::Global:
4677 return *cacheLevel == CacheLevel::L1
4679 {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
4681 {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
4682 case MemSpace::Local:
4683 return *cacheLevel == CacheLevel::L1
4685 {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
4687 {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
4689 llvm_unreachable("Invalid pointer address space");
4693bool NVVM::InlinePtxOp::getAsmValues(
4697 for (auto arg : getReadWriteArgs())
4699 for (auto arg : getResults())
4701 for (auto arg : getReadOnlyArgs())
4708NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
4710 auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
4712 args.push_back(mt.lookupValue(curOp.getSmemAddress()));
4713 args.push_back(mt.lookupValue(curOp.getMbarrier()));
4715 llvm::Intrinsic::ID intrinsicID =
4716 curOp.getMulticast()
4718 nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
4719 : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;
4721 return {intrinsicID, args};
4724NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
4726 auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
4728 args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));
4730 llvm::Intrinsic::ID intrinsicID;
4732 switch (curOp.getQueryType()) {
4733 case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
4735 llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
4737 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
4738 intrinsicID = llvm::Intrinsic::
4739 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
4741 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
4742 intrinsicID = llvm::Intrinsic::
4743 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
4745 case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
4746 intrinsicID = llvm::Intrinsic::
4747 nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
4750 return {intrinsicID, args};
4755 llvm::IRBuilderBase &builder) {
4756 auto thisOp = cast<NVVM::PermuteOp>(op);
4757 NVVM::PermuteMode mode = thisOp.getMode();
4759 static constexpr llvm::Intrinsic::ID IDs[] = {
4760 llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
4761 llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
4762 llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
4763 llvm::Intrinsic::nvvm_prmt_rc16};
4765 unsigned modeIndex = static_cast<unsigned>(mode);
4773 args.push_back(mt.lookupValue(thisOp.getSelector()));
4775 return {IDs[modeIndex], args};
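// Lowering for tcgen05.mma: the intrinsic is picked from a five-dimensional
// constexpr table indexed, from outermost to innermost, by
// [hasDisableOutputLane][hasScaleInputD][isATensor][ctaGroup - 1][aShift].
// Entries for unsupported combinations hold notIntrinsic and are rejected by
// the assert after the lookup; for example, no disable-output-lane vector, no
// scale input, A in tensor memory, cta_group 1 and no A-shift selects
// tcgen05MMAIDs[0][0][1][0][0], i.e. nvvm_tcgen05_mma_tensor.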
4784 llvm::IRBuilderBase &builder) {
4786 auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
4789 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
4792 const bool isATensor = isa<llvm::PointerType>(A->getType());
4795 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
4796 args.push_back(mt.lookupValue(thisOp.getIdesc()));
4797 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
4799 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
4800 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
4801 using IsATensorArray = std::array<CtaGroupArray, 2>;
4802 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
4803 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
4806 static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
4812 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic},
4814 {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}},
4818 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
4819 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
4823 llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
4824 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
4830 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic},
4832 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}},
4836 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
4837 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
4841 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
4842 llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
4848 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
4851 {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
4856 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
4858 nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
4863 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
4865 nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
4871 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
4875 nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
4880 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
4882 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift},
4886 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
4888 nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift,
4891 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
4892 bool hasScaleInputD = ScaleInputD != nullptr;
4894 llvm::Value *DisableOutputLane =
4896 bool hasDisableOutputLane = DisableOutputLane != nullptr;
4898 const unsigned ctaGroup =
4901 llvm::Intrinsic::ID ID =
4902 tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
4903 [ctaGroup - 1][thisOp.getAShift()];
4905 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");
4908 args.push_back(ScaleInputD);
4910 if (hasDisableOutputLane)
4911 args.push_back(DisableOutputLane);
4913 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
4915 if (!hasDisableOutputLane)
4916 args.push_back(builder.getInt32(ctaGroup));
4919 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
4926 NVVM::CTAGroupKind ctaGroup, bool hasAShift,
4927 NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
4929 if (disableOutputLane) {
4930 mlir::VectorType disableOutputLaneType =
4931 cast<mlir::VectorType>(disableOutputLane.getType());
4932 if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
4933 disableOutputLaneType.getNumElements() != 4) ||
4934 (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
4935 disableOutputLaneType.getNumElements() != 8))
4936 return emitError(loc) << "Disable Output Lane of length "
4937 << disableOutputLaneType.getNumElements()
4938 << " is incompatible with CtaGroupAttr";
4941 if (hasAShift && !isATensor)
4943 loc, "A-shift can be applied only when matrix A is in tensor memory");
4945 if (hasAShift && (collectorOp == Tcgen05MMACollectorOp::FILL ||
4946 collectorOp == Tcgen05MMACollectorOp::USE))
4948 loc, "Cannot use collector buffer operation fill or use with ashift");
4953LogicalResult Tcgen05MMAOp::verify() {
4955 getDisableOutputLane(), getCtaGroup(), getAShift(),
4956 getCollectorOp(), getLoc());
4966 auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
4969 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
4972 bool isATensor = isa<llvm::PointerType>(A->getType());
4975 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
4976 args.push_back(mt.lookupValue(thisOp.getIdesc()));
4977 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
4978 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
4980 using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
4981 using CtaGroupArray = std::array<EnableAShiftArray, 2>;
4982 using IsATensorArray = std::array<CtaGroupArray, 2>;
4983 using HasScaleInputDArray = std::array<IsATensorArray, 2>;
4984 using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;
4987 static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
4993 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic},
4995 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}},
4999 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5000 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5004 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
5005 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
5011 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5014 {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d,
5019 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5020 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5024 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
5025 llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
5032 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
5036 nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
5041 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
5043 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
5048 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
5050 nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
5056 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
5060 nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
5065 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
5067 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift},
5071 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
5073 nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift,
5076 llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
5077 bool hasScaleInputD = ScaleInputD != nullptr;
5079 llvm::Value *DisableOutputLane =
5081 bool hasDisableOutputLane = DisableOutputLane != nullptr;
5086 llvm::Intrinsic::ID ID =
5087 tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
5088 [ctaGroup - 1][thisOp.getAShift()];
5090 assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");
5093 args.push_back(ScaleInputD);
5095 if (hasDisableOutputLane)
5096 args.push_back(DisableOutputLane);
5098 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5100 if (!hasDisableOutputLane)
5101 args.push_back(builder.getInt32(ctaGroup));
5104 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5109LogicalResult Tcgen05MMASparseOp::verify() {
5111 getDisableOutputLane(), getCtaGroup(), getAShift(),
5112 getCollectorOp(), getLoc());
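// Lowering for tcgen05.mma.block_scale: the operands (D, A, B, idesc,
// enable-D, scale-A, scale-B) are marshalled first, then the intrinsic is
// chosen by the MMA kind (mxf8f6f4, mxf4, mxf4nvf4), the block-scale
// granularity (default, block16, block32), and whether matrix A lives in
// tensor or shared memory. Invalid combinations are rejected by the verifier
// below and trap in llvm_unreachable here.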
5122 auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
5125 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5128 bool isATensor = isa<llvm::PointerType>(A->getType());
5131 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5132 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5133 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5134 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5135 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5136 args.push_back(builder.getInt32(
5139 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5141 auto kind = thisOp.getKind();
5142 auto blockScale = thisOp.getBlockScale();
5143 llvm::Intrinsic::ID ID = [&]() {
5144 if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
5145 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5146 return isATensor ? llvm::Intrinsic::
5147 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
5149 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
5150 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5153 nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
5155 nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
5157 } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
5158 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5160 ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
5161 : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
5162 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5163 return isATensor ? llvm::Intrinsic::
5164 nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
5166 nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
5168 } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
5169 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5172 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
5174 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
5176 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5179 nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
5181 nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
5184 llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
5191 NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::MMABlockScaleKind kind,
5192 NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
5194 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
5195 kind == MMABlockScaleKind::MXF4NVF4)
5196 return emitError(loc, "mxf4nvf4 requires block scale attribute");
5198 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
5199 kind != MMABlockScaleKind::MXF4NVF4)
5201 llvm::formatv("{} kind does not support block16 attribute",
5202 stringifyEnum(kind)));
5207LogicalResult Tcgen05MMABlockScaleOp::verify() {
5209 getBlockScale(), getLoc());
5219 auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
5222 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5225 bool isATensor = isa<llvm::PointerType>(A->getType());
5228 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5229 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5230 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5231 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5232 args.push_back(mt.lookupValue(thisOp.getScaleA()));
5233 args.push_back(mt.lookupValue(thisOp.getScaleB()));
5234 args.push_back(builder.getInt32(
5237 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5239 auto kind = thisOp.getKind();
5240 auto blockScale = thisOp.getBlockScale();
5241 llvm::Intrinsic::ID ID = [&]() {
5242 if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
5243 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5244 return isATensor ? llvm::Intrinsic::
5245 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
5247 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
5248 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5251 nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
5253 nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
5255 } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
5256 if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
5257 return isATensor ? llvm::Intrinsic::
5258 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
5260 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
5261 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5264 nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
5266 nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
5268 } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
5269 if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
5272 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
5274 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
5276 } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
5279 nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
5281 nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
5284 llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
5290LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
5292 getBlockScale(), getLoc());
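// Lowering for the weight-stationary tcgen05.mma.ws variants: the presence of
// the optional zero-column-mask operand switches between the _zero_col_mask
// and plain intrinsics, and matrix A's location (tensor vs. shared memory)
// picks the concrete entry point; kind, collector-B buffer and collector op
// are appended as immediate i32 arguments.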
5302 auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
5305 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5308 bool isATensor = isa<llvm::PointerType>(A->getType());
5311 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5312 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5313 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5315 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5319 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
5320 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
5322 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
5323 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
5325 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5327 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5329 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5341 auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
5344 args.push_back(mt.lookupValue(thisOp.getMatrixD()));
5347 bool isATensor = isa<llvm::PointerType>(A->getType());
5350 args.push_back(mt.lookupValue(thisOp.getMatrixB()));
5351 args.push_back(mt.lookupValue(thisOp.getIdesc()));
5352 args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
5353 args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
5355 mlir::Value ZeroColMask = thisOp.getZeroColMask();
5360 ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
5361 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
5363 ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
5364 : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
5366 args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
5368 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
5370 builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
5380void NVVMDialect::initialize() {
5383#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5386#define GET_ATTRDEF_LIST
5387#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
5392 allowUnknownOperations();
5393 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
5394 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
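// Dialect-level verification of NVVM discardable attributes: nvvm.kernel may
// only appear on LLVM functions; maxntid/reqntid/cluster_dim must be dense
// i32 arrays with at most three entries; minctasm, maxnreg and the cluster
// max-blocks attribute must be integer constants; and blocks_are_clusters
// requires both reqntid and cluster_dim to be present.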
5397LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
5399 StringAttr attrName = attr.getName();
5401 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
5402 if (!isa<LLVM::LLVMFuncOp>(op)) {
5403 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
5404 << "' attribute attached to unexpected op";
5409 if (attrName == NVVMDialect::getMaxntidAttrName() ||
5410 attrName == NVVMDialect::getReqntidAttrName() ||
5411 attrName == NVVMDialect::getClusterDimAttrName()) {
5412 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
5413 if (!values || values.empty() || values.size() > 3) {
5416 << "' attribute must be integer array with maximum 3 index";
5421 if (attrName == NVVMDialect::getMinctasmAttrName() ||
5422 attrName == NVVMDialect::getMaxnregAttrName() ||
5423 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
5424 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
5426 << "'" << attrName << "' attribute must be integer constant";
5430 if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
5431 if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
5432 !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
5434 << "'" << attrName << "' attribute must be used along with "
5435 << "'" << NVVMDialect::getReqntidAttrName() << "' and "
5436 << "'" << NVVMDialect::getClusterDimAttrName() << "'";
5443LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
5444 unsigned regionIndex,
5447 auto funcOp = dyn_cast<FunctionOpInterface>(op);
5451 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
5452 StringAttr attrName = argAttr.getName();
5453 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
5457 << "' attribute must be present only on kernel arguments";
5459 if (!isa<UnitAttr>(argAttr.getValue()))
5460 return op->emitError() << "'" << attrName << "' must be a unit attribute";
5461 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
5464 << "' attribute requires the argument to also have attribute '"
5465 << LLVM::LLVMDialect::getByValAttrName() << "'";
5476unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
5477 return static_cast<unsigned>(getValue());
5480bool NVVMMemorySpaceAttr::isValidLoad(
5481 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5482 const ::mlir::DataLayout *dataLayout,
5488bool NVVMMemorySpaceAttr::isValidStore(
5489 Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
5490 const ::mlir::DataLayout *dataLayout,
5496bool NVVMMemorySpaceAttr::isValidAtomicOp(
5497 ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
5498 std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
5501 assert(false && "unimplemented, see TODO in the source.");
5505bool NVVMMemorySpaceAttr::isValidAtomicXchg(
5506 Type type, ptr::AtomicOrdering successOrdering,
5507 ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
5508 const ::mlir::DataLayout *dataLayout,
5511 assert(false && "unimplemented, see TODO in the source.");
5515bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
5519 assert(false && "unimplemented, see TODO in the source.");
5523bool NVVMMemorySpaceAttr::isValidPtrIntCast(
5528 assert(false && "unimplemented, see TODO in the source.");
5537 int optLevel, StringRef triple, StringRef chip,
5538 StringRef features, DictionaryAttr flags,
5540 if (optLevel < 0 || optLevel > 3) {
5541 emitError() << "The optimization level must be a number between 0 and 3.";
5544 if (triple.empty()) {
5545 emitError() << "The target triple cannot be empty.";
5549 emitError() << "The target chip cannot be empty.";
5552 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
5553 return mlir::isa_and_nonnull<StringAttr>(attr);
5555 emitError() << "All the elements in the `link` array must be strings.";
5561LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
5562 if (!getVerifyTarget())
5565 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
5568 "NVVM target attribute must be attached to a GPU module");
5571 const NVVMCheckSMVersion targetSMVersion =
5575 "Minimum NVVM target SM version is sm_20");
5579 ->walk([&](Operation *op) {
5580 if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
5581 const NVVMCheckSMVersion requirement =
5582 reqOp.getRequiredMinSMVersion();
5584 op->emitOpError() << "is not supported on " << getChip();
5596#define GET_OP_CLASSES
5597#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
5599#define GET_ATTRDEF_CLASSES
5600#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
static void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
static void addBlockScaleProperties(OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::MMABlockScaleKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
static void printOperandList(OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static llvm::Value * castPtrToAddrSpace(llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)
static void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
static LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > ®s)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static MMATypes inferPtxTypeFromResult(OpTy op)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static LogicalResult parseMmaTypeSignature(OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
static bool isPtrInSharedClusterSpace(mlir::Value ptr)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static bool isPtrInGenericSpace(mlir::Value ptr)
static void processOperandFragments(Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > ®Types, SmallVectorImpl< StringRef > &ignoreAttrNames)
static constexpr unsigned notIntrinsic
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class represents a diagnostic that is inflight and set to be reported.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given chracteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
bool isMinimumSMVersion() const
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const
static const NVVMCheckSMVersion getTargetSMVersionFromStr(StringRef smVersionString)
This represents an operation in an abstracted form, suitable for use with the builder APIs.