#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
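// Helper used by the verifiers below: returns true when the pointer value's
// LLVM address space matches the expected NVVM target address space.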
  auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
  return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
static llvm::nvvm::CTAGroupKind
  case NVVM::CTAGroupKind::CTA_1:
    return llvm::nvvm::CTAGroupKind::CG_1;
  case NVVM::CTAGroupKind::CTA_2:
    return llvm::nvvm::CTAGroupKind::CG_2;
  llvm_unreachable("unsupported cta_group value");
                                              size_t numIm2ColOffsets,
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimensions");

        "to use im2col mode, the tensor has to be at least 3-dimensional");

  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
    return emitError(
        loc, "im2col offsets must be 2 less than number of coordinates");
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  TMAStoreMode mode = getMode();
  if (getPredicate()) {
    if (mode != TMAStoreMode::TILE)
      return emitError("Inline-ptx lowering supported only for Tile mode.");
    if (getL2CacheHint())
      return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");

  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter4 mode expects 5 coordinates");
LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16 bytes copy.");
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimensions");

  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedIm2colOff) -> LogicalResult {
    if (isIm2col && (tensorDims < 3))
             << "to use " << stringifyEnum(mode)
             << " mode, the tensor has to be at least 3-dimensional";

    if (numIm2colOff != expectedIm2colOff)
      return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
                            << " (provided " << numIm2colOff << ")";

  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    return (tensorDims == 5)
               ? checkTMALoadParams(mode, false, 0)
               : emitError(loc, "Gather4 mode expects 5 coordinates");
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
                             getMode(), getLoc());
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  TMALoadMode mode = getMode();
  bool isCTAOnly = getIsCTAOnly();
  if (getPredicate()) {
      return emitError("Predicate is supported only for shared::cluster mode.");
    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
          "Predicate is supported only for Tile and Im2col modes.");

  NVVMMemorySpace expectedAS =
      isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
  if (AS != expectedAS)
                   ? "Shared::cta destination requires address-space 3."
                   : "Shared::cluster destination requires address-space 7.");

    if (getMulticastMask())
      return emitError("Multicast is not supported with shared::cta mode.");
      return emitError("CTAGroup is not supported with shared::cta mode.");

                             getMode(), getLoc());
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  TMAStoreMode mode = getMode();
  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");

LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
  if (isSharedCTA && getMulticastMask())
    return emitError("Multicast is not supported with shared::cta mode.");
LogicalResult ConvertFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
    return emitError("Relu not supported with rna rounding mode.");
        "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");

LogicalResult ConvertF32x2ToF6x2Op::verify() {
  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f32x2 to f6x2.";

LogicalResult ConvertF32x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;

  bool hasRelu = getRelu();

      .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
            if (!isRoundingModeRN) {
              return emitOpError("Only RN rounding mode is supported for "
                                 "conversions from f32x2 to ")
                     << mlir::Float8E4M3FNType::get(ctx) << " and "
                     << mlir::Float8E5M2Type::get(ctx) << " types";
              return emitOpError("Only SATFINITE saturation mode is supported "
                     << mlir::Float8E4M3FNType::get(ctx) << " and "
                     << mlir::Float8E5M2Type::get(ctx) << " types";
      .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
        if (!(isRoundingModeRZ || isRoundingModeRP)) {
          return emitOpError("Only RZ and RP rounding modes are supported for "
                             "conversions from f32x2 to ")
                 << mlir::Float8E8M0FNUType::get(ctx) << " type";
          return emitOpError("relu not supported for conversions to ")
                 << mlir::Float8E8M0FNUType::get(ctx) << " type";
               << mlir::Float8E4M3FNType::get(ctx) << ", "
               << mlir::Float8E5M2Type::get(ctx) << ", and "
               << mlir::Float8E8M0FNUType::get(ctx)
                  "supported for conversions from f32x2 to f8x2";

LogicalResult ConvertF16x2ToF8x2Op::verify() {
  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f16x2 to f8x2.";

LogicalResult ConvertBF16x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;

  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
           << " type is supported for conversions from "

  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");

LogicalResult ConvertF32x2ToF4x2Op::verify() {
  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
           << mlir::Float4E2M1FNType::get(ctx)
           << " type is supported for conversions from f32x2 to f4x2.";

LogicalResult ConvertF8x2ToF16x2Op::verify() {
  if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f8x2 to f16x2.";

LogicalResult ConvertF8x2ToBF16x2Op::verify() {
  if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
           << mlir::Float8E8M0FNUType::get(ctx)
           << " type is supported for conversions from f8x2 to bf16x2.";

LogicalResult ConvertF6x2ToF16x2Op::verify() {
  if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f6x2 to f16x2.";

LogicalResult ConvertF4x2ToF16x2Op::verify() {
  if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
           << mlir::Float4E2M1FNType::get(ctx)
           << " type is supported for conversions from f4x2 to f16x2.";

LogicalResult ConvertF32x2ToF16x2Op::verify() {
  if (getRnd() != FPRoundingMode::RS)
    return emitOpError("Only RS rounding mode is supported for "
                       "conversions from f32x2 to f16x2.");

LogicalResult ConvertF32x2ToBF16x2Op::verify() {
  if (getRnd() != FPRoundingMode::RS)
    return emitOpError("Only RS rounding mode is supported for "
                       "conversions from f32x2 to bf16x2.");

LogicalResult ConvertF32x4ToF8x4Op::verify() {
  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f32x4 to f8x4.";

LogicalResult ConvertF32x4ToF6x4Op::verify() {
  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f32x4 to f6x4.";

LogicalResult ConvertF32x4ToF4x4Op::verify() {
  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx)
                                << " type is supported for conversions from "
LogicalResult BulkStoreOp::verify() {
  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ") << getInitVal();

LogicalResult PMEventOp::verify() {
  auto eventId = getEventId();
  auto maskedEventId = getMaskedEventId();
  if (!maskedEventId && !eventId) {
    return emitOpError() << "either `id` or `mask` must be set";

  if (maskedEventId && eventId) {
    return emitOpError() << "`id` and `mask` cannot be set at the same time";

  if (eventId < 0 || eventId > 15) {
    return emitOpError() << "`id` must be between 0 and 15";

  return llvm::success();
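// Infers the PTX MMA element type (f64, f16, f32, tf32, s32, ...) from an
// operand's MLIR element type. An f32 operand maps to f32 when it is the
// accumulator and to tf32 otherwise; struct operands are inferred from their
// first element type.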
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
      return NVVM::MMATypes::s32;

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);

  return (type == MMATypes::u4 || type == MMATypes::s4);

  return (type == MMATypes::u8 || type == MMATypes::s8);

         type == MMATypes::s32;
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), fragIdx >= 2);
      ignoreAttrNames.push_back(frag.ptxTypeAttr);

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;

  for (const auto &frag : frags) {
    printMmaOperand(frag);

                  frags[1].regs[0].getType(),
                  frags[2].regs[0].getType()},
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
      MmaOp::getOperandSegmentSizeAttr(),
                           static_cast<int32_t>(operandB.size()),
                           static_cast<int32_t>(operandC.size())}));
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;

  std::array<OperandFragment, 4> frags;

  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {

  if (parseMmaOperand("A", frags[0]).failed())
  if (parseMmaOperand("B", frags[1]).failed())
  if (parseMmaOperand("C", frags[2]).failed())

  if (operandTypes.size() != 3)
        "expected one type for each operand segment but got " +
        Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],

  frags[3].elemtype = inferOperandMMAType(resultType, true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
          "attribute " + names[idx] +
          " is not provided explicitly and cannot be inferred");
    if (!attr.has_value())
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      static_cast<int32_t>(frags[0].regs.size()),
                      static_cast<int32_t>(frags[1].regs.size()),
                      static_cast<int32_t>(frags[2].regs.size()),
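// MmaOp::verify() enumerates the operand/result fragment types and shapes that
// are legal for the given m/n/k and PTX element types, then checks the actual
// operand segments, the result struct, and the layout, b1Op, and
// intOverflowBehavior attribute requirements against that set.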
LogicalResult MmaOp::verify() {
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});

  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;

  if (mmaShape[0] == 16) {
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));

      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));

      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);

      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));

      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;

      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});

    if (resultPtxType() != accumPtxType())

  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});

    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
      allowedShapes.push_back({8, 8, 4});

      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
        allowedShapes.push_back({8, 8, 32});
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);

  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
    auto spec = this->getODSOperandIndexAndLength(iter.index());
                              operand_type_begin() + spec.first +
    bool match = llvm::is_contained(iter.value(), operandTySeg);

      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);

  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();

  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +

    if (!getIntOverflowBehavior())
                       getIntOverflowBehaviorAttrName().strref() +

      (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
       getMultiplicandAPtxType() == MMATypes::f16);

    if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
      return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
                         "layoutB = #nvvm.mma_layout<col> for shape <")
             << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
             << "> with element types "
             << stringifyEnum(*getMultiplicandAPtxType()) << " and "
             << stringifyEnum(*getMultiplicandBPtxType())
             << ". Only m8n8k4 with f16 supports other layouts.";
LogicalResult ShflOp::verify() {
  auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());

  auto verifyTypeError = [&](Twine desc, Type expectedType,
                             Type actualType) -> LogicalResult {
    return emitOpError("expected " + desc + " to be of type ")
           << expectedType << " but got " << actualType << " instead";

  if (returnStructType) {
    if (!getReturnValueAndIsValid())
      return emitOpError("\"return_value_and_is_valid\" attribute must be "
                         "specified when the return type is a struct type");

    if (returnStructType.getBody().size() != 2)
      return emitOpError("expected return type to be a two-element struct");

    auto resultType = returnStruct[0];
    if (resultType != getVal().getType())
      return verifyTypeError("first element in the returned struct",
                             getVal().getType(), resultType);

    auto predicateType = returnStruct[1];
    if (!predicateType.isInteger(1))
      return verifyTypeError("second element in the returned struct",

  if (getReturnValueAndIsValid())
    return emitOpError("expected return type to be a two-element struct");

    return verifyTypeError("return type", getVal().getType(), getType());
                                                  NVVM::MMAFrag frag, int nRow,
  unsigned numberElements = 0;

  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
  } else if (type == NVVM::MMATypes::f32) {
  } else if (type == NVVM::MMATypes::f64) {
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
  } else if (type == NVVM::MMATypes::tf32) {
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    if (parallelSize == 16)
    else if (parallelSize == 8)
    else if (parallelSize == 32)
  } else if (type == NVVM::MMATypes::s32) {

  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);

static std::pair<mlir::Type, unsigned>
  if (frag == NVVM::MMAFrag::a) {
  } else if (frag == NVVM::MMAFrag::b) {
  assert(nRow && nCol);
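// WMMA load/store verification: the pointer must be in the generic, global, or
// shared address space, the m/n/k/layout/eltype/frag combination must map to a
// real intrinsic, and the fragment operands/results must match the element
// type and count computed for that fragment.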
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected source pointer in memory "

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";

  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
      return emitOpError("expected destination type to be f64");

  Type dstType = LLVM::LLVMStructType::getLiteral(
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected operands to be a source pointer in memory "

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
    return emitOpError() << "invalid attribute combination";

  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
    return emitOpError() << "expected data operands of type " << typeInfo.first;
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
    return emitOpError() << "invalid attribute combination";

  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "

  Type dstType = LLVM::LLVMStructType::getLiteral(
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
LogicalResult NVVM::LdMatrixOp::verify() {
  if (m == 8 && n == 8) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
    if (getEltType() != LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be b16 for 8x8 matrix");
  } else if (m == 8 && n == 16) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
    if (getLayout() != MMALayout::row) {
      return emitOpError("expected layout to be row for 8x16 matrix");
    if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 8x16 matrix");
  } else if (m == 16 && n == 16) {
    if (num != 1 && num != 2) {
      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
    if (getLayout() != MMALayout::col) {
      return emitOpError("expected layout to be col for 16x16 matrix");
    if (getEltType() != LdStMatrixEltType::B8 &&
        getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 16x16 matrix");
    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");

  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
  if (numElements == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (numElements == 2 || numElements == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
      return emitOpError("expected destination type is a structure of ")
             << numElements << " elements of type i32";
LogicalResult NVVM::StMatrixOp::verify() {
  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  if (m == 8 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be B16 for 8x8 matrix");
  } else if (m == 16 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B8) {
      return emitOpError("expected element type to be B8 for 16x8 matrix");
    if (getLayout() != NVVM::MMALayout::col) {
      return emitOpError("expected layout to be col for 16x8 matrix");
    return emitOpError("expected shape to be 8x8 or 16x8");
  if (typeA == NVVM::WGMMATypes::tf32)
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
  if (typeA == NVVM::WGMMATypes::b1)

                                         NVVM::WGMMATypes typeA,
                                         NVVM::WGMMATypes typeB) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");

                                  72,  80,  88,  96,  104, 112, 120, 128,
                                  136, 144, 152, 160, 168, 176, 184, 192,
                                  200, 208, 216, 224, 232, 240, 248, 256};
                                       80,  96,  112, 128, 144, 160,
                                       176, 192, 208, 224, 240, 256};
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
          << "all elements in struct must be same type but there is " << t;

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";

    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";

    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";

  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
        << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
        << " and layout_b = " << stringifyMMALayout(getLayoutB())
        << " for input types " << stringifyWGMMATypes(typeA) << " and "
        << stringifyWGMMATypes(typeB)
        << " requires transpose. However, this is only supported for: "
        << stringifyMMATypes(MMATypes::f16) << " and "
        << stringifyMMATypes(MMATypes::bf16);

  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize

  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
        << " `satfinite` can be only used with s32 accumulator, however "
           "the current accumulator is "
        << NVVM::stringifyWGMMATypes(typeD);
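// Builds the inline-PTX string for wgmma.mma_async: the output register list
// comes first, followed by the remaining operands and, where applicable, the
// scale and layout immediates for non-s32 accumulators.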
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
    expectedOutputRegisters = getShape().getN() / 2;

  llvm::raw_string_ostream ss(ptx);

     << ((expectedOutputRegisters * 2) + 2)
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)

  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)

  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","

  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);

    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),

  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),

    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");
LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
        "barrier id is missing, it should be set between 0 to 15");

  if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
    return emitOpError("reductions are only available when id is 0");

    return emitOpError("reduction predicate and reduction operation must be "
                       "specified together");
LogicalResult NVVM::Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
      return emitError("Invalid multicast type for tcgen05.cp Op");
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
LogicalResult NVVM::MatchSyncOp::verify() {
  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two element struct with "
                         "first element as i32 and second element as i1");
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");

LogicalResult NVVM::VoteSyncOp::verify() {
  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
LogicalResult NVVM::PrefetchOp::verify() {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();

  if (getTensormap() && cacheLevel)
    return emitOpError("cannot specify both tensormap and cache level");

  if (getTensormap()) {
    if (addressSpace != MemSpace::Generic &&
        addressSpace != MemSpace::Constant) {
          "prefetch tensormap requires a generic or constant pointer");

    if (evictPriority) {
          "prefetch tensormap does not support eviction priority");

    if (getInParamSpace() && addressSpace != MemSpace::Generic) {
          "in_param_space can only be specified for a generic pointer");

  } else if (cacheLevel) {
    if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
        addressSpace != MemSpace::Local) {
      return emitOpError("prefetch to cache level requires a generic, global, "
                         "or local pointer");

      if (*cacheLevel != CacheLevel::L1) {
            "unsupported cache level, the only supported uniform "
            "cache level is L1");

      if (addressSpace != MemSpace::Generic) {
            "prefetch to uniform cache requires a generic pointer");

    if (evictPriority) {
      if (*cacheLevel != CacheLevel::L2)
            "cache eviction priority supported only for cache level L2");

      if (addressSpace != MemSpace::Global)
        return emitOpError("cache eviction priority requires a global pointer");

      if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
          *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
            "unsupported cache eviction priority, only evict_last and "
            "evict_normal are supported");

    return emitOpError("predicate supported only on prefetch tensormap");
        "requires specification of either cache level or tensormap");
LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
  switch (getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
      return emitOpError("is_canceled query type returns an i1");
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    if (!getType().isInteger(32)) {
      return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
                         "get_first_cta_id_z query types return an i32");
LogicalResult NVVM::ReduxOp::verify() {
  if (!reduxType.isF32()) {
      return emitOpError("abs attribute is supported only for f32 type");
      return emitOpError("nan attribute is supported only for f32 type");

  NVVM::ReduxKind kind = getKind();
  case NVVM::ReduxKind::ADD:
  case NVVM::ReduxKind::AND:
  case NVVM::ReduxKind::OR:
  case NVVM::ReduxKind::XOR:
  case NVVM::ReduxKind::MAX:
  case NVVM::ReduxKind::MIN:
  case NVVM::ReduxKind::UMAX:
  case NVVM::ReduxKind::UMIN:
             << stringifyEnum(kind) << "' redux kind unsupported with "
             << reduxType << " type. Only supported type is 'i32'.";
  case NVVM::ReduxKind::FMIN:
  case NVVM::ReduxKind::FMAX:
    if (!reduxType.isF32())
             << stringifyEnum(kind) << "' redux kind unsupported with "
             << reduxType << " type. Only supported type is 'f32'.";
                                   unsigned sizeInBits,
  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());

  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
  if (mask != 0xffffffffu)
    field = builder.CreateAnd(field, builder.getInt32(mask));

  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
  field = builder.CreateShl(field, start);

  return builder.CreateOr(result, field);

void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
  llvm::Value *smemDesc = builder.getInt64(0);

      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);

      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);

  mt.mapValue(thisOp.getRes()) = smemDesc;
std::string NVVM::MBarrierInitOp::getPtx() {
  return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
                  : std::string("mbarrier.init.b64 [%0], %1;");

std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
             ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
             : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");

std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
  llvm::StringRef space = isShared ? ".shared" : "";

  return llvm::formatv("{\n\t"
                       ".reg .pred P1; \n\t"
                       "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
                       "@P1 bra.uni DONE; \n\t"
                       "bra.uni LAB_WAIT; \n\t"
  auto thisOp = cast<NVVM::BarrierOp>(op);
  llvm::Value *barrierId = thisOp.getBarrierId()
                               : builder.getInt32(0);
  llvm::Intrinsic::ID id;

  if (thisOp.getNumberOfThreads()) {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
    args.push_back(barrierId);
    args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
  } else if (thisOp.getReductionOp()) {
    switch (*thisOp.getReductionOp()) {
    case NVVM::BarrierReduction::AND:
      id = llvm::Intrinsic::nvvm_barrier0_and;
    case NVVM::BarrierReduction::OR:
      id = llvm::Intrinsic::nvvm_barrier0_or;
    case NVVM::BarrierReduction::POPC:
      id = llvm::Intrinsic::nvvm_barrier0_popc;
    args.push_back(mt.lookupValue(thisOp.getReductionPredicate()));
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
    args.push_back(barrierId);

  return {id, std::move(args)};
  auto thisOp = cast<NVVM::MBarrierInitOp>(op);
  llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
                                    : llvm::Intrinsic::nvvm_mbarrier_init;
  args.push_back(mt.lookupValue(thisOp.getCount()));
  return {id, std::move(args)};

  auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
  llvm::Intrinsic::ID id = isShared
                               ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
                               : llvm::Intrinsic::nvvm_mbarrier_inval;

  auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
  llvm::Intrinsic::ID id = isShared
                               ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared
                               : llvm::Intrinsic::nvvm_mbarrier_arrive;

  auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
  llvm::Intrinsic::ID id =
      isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
               : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;
  args.push_back(mt.lookupValue(thisOp.getCount()));
  return {id, std::move(args)};

  auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
  llvm::Intrinsic::ID id = isShared
                               ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared
                               : llvm::Intrinsic::nvvm_mbarrier_test_wait;
  args.push_back(mt.lookupValue(thisOp.getState()));
  return {id, std::move(args)};

  auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
  llvm::Intrinsic::ID id;
  if (thisOp.getNoinc()) {
    id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
                  : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
    id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
                  : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
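// The CP_ASYNC_ID_IMPL / GET_CP_ASYNC_ID macros select the cp.async intrinsic
// for a given cache modifier (ca/cg) and copy size, optionally using the
// cp-size (_s) variant.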
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )

  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
    llvm_unreachable("Invalid copy size in CpAsyncOp.");

  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));

  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;

  args.push_back(mt.lookupValue(thisOp.getSrcMem()));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  return {id, std::move(args)};
  auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);

  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));

  mlir::Value multicastMask = thisOp.getMulticastMask();
  const bool hasMulticastMask = static_cast<bool>(multicastMask);
  llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);

  args.push_back(builder.getInt1(hasMulticastMask));
  args.push_back(builder.getInt1(hasCacheHint));

  llvm::Intrinsic::ID id =
          ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
          : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;

  return {id, std::move(args)};

  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
  llvm::Intrinsic::ID id =
      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;

  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  if (mlir::Value byteMask = thisOp.getByteMask()) {
    id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;

  return {id, std::move(args)};
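// Lowering of the TMA tensor global-to-shared copy: arguments are pushed in
// intrinsic order (destination, descriptor, coordinates, multicast mask, cache
// hint and their presence flags), then the intrinsic is looked up in per-mode,
// per-dimension tables, one for shared::cluster and one for the CTA-only form.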
bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
  for (auto val : getOperands())

CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
  const bool isCTAOnly = thisOp.getIsCTAOnly();

  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  const bool hasMC = static_cast<bool>(mcMask);
  llvm::Value *i16Zero =
      llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Zero =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);

      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);

    args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
    args.push_back(builder.getInt1(hasMC));
    args.push_back(builder.getInt1(hasCacheHint));

    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
    args.push_back(builder.getInt1(hasCacheHint));

  constexpr size_t numDims = 5;
  constexpr size_t numModes = 5;
  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
  using TableTy = std::array<rowTy, numModes>;
  static constexpr TableTy IDTable{
      {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};

  static constexpr TableTy IDTableCTA{
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};

      (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
          (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
      "TMALoadModes must match number of rows in IDTable and IDTableCTA");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
                 "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");

  return {id, std::move(args)};
  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);

  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
  for (auto v : thisOp.getIm2colOffsets())

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  const unsigned NI = llvm::Intrinsic::not_intrinsic;
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
      {NI, NI, NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};

  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
                "TMALoadModes must match number of rows in IDTable");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");

  return {id, std::move(args)};
NVVM::IDArgPair CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the intrinsic args: source, TMA descriptor and coordinates.
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  const unsigned NI = llvm::Intrinsic::not_intrinsic;
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
      {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
      {NI, NI, NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};

  static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
                "TMAStoreModes must match number of rows in IDTable");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable(
        "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");

  return {id, std::move(args)};
}
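// Bulk tensor reduction lowering: the intrinsic is looked up by reduction
// kind (add/min/max/inc/dec/and/or/xor), TMA mode (tile vs. im2col) and
// coordinate rank, i.e. IDTable[redKind][mode][dim].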
NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
  llvm::LLVMContext &ctx = mt.getLLVMContext();
  llvm::SmallVector<llvm::Value *> args;

  // Fill the intrinsic args: source, TMA descriptor and coordinates.
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (Value v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64ZeroValue =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
  args.push_back(builder.getInt1(hasCacheHint));

  const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;

  constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
  constexpr unsigned numLayouts = 2;  // TILE, IM2COL
  constexpr unsigned maxDim = 5;      // 1D through 5D
  using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
  using layoutTable = std::array<row, numLayouts>;
  using fullTable = std::array<layoutTable, numRedKinds>;
  static constexpr fullTable IDTable{
      {{// RedTy::ADD
        {{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
       {// RedTy::MIN
        {{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
       {// RedTy::MAX
        {{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
       {// RedTy::INC
        {{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
       {// RedTy::DEC
        {{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
       {// RedTy::AND
        {{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
       {// RedTy::OR
        {{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
       {// RedTy::XOR
        {{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
           llvm::Intrinsic::
               nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};

  static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
                "TMAReduxKinds must match number of rows in IDTable");

  size_t redKind = static_cast<size_t>(thisOp.getRedKind());
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();

  assert(redKind < IDTable.size() &&
         "Invalid redKind for CpAsyncBulkTensorReduceOp");
  assert(mode < IDTable[redKind].size() &&
         "Invalid mode for CpAsyncBulkTensorReduceOp");
  assert(dim < IDTable[redKind][mode].size() &&
         "Invalid dim for CpAsyncBulkTensorReduceOp");

  llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
  if (intrinsicID == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorReduceOp.");

  return {intrinsicID, std::move(args)};
}
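// The conversion helpers below pick nvvm conversion intrinsics for narrow
// floating-point types; the macros expand to the relu/satfinite variants of
// a given base intrinsic name.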
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

llvm::Intrinsic::ID
ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, , _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}
NVVM::IDArgPair
ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(op.getA()));
  args.push_back(mt.lookupValue(op.getB()));

  bool hasRelu = op.getRelu();

  llvm::Intrinsic::ID intId =
      hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
              : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;

  return {intId, std::move(args)};
}

#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
        return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
      })
      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
        return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

llvm::Intrinsic::ID
ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
        return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
      })
      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
        return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
      })
      .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
        if (hasRoundingModeRZ)
          return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
        else if (hasRoundingModeRP)
          return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}
#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
        return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
      })
      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
        return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}
NVVM::IDArgPair
ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(Operation &op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);

  bool hasRelu = curOp.getRelu();

  llvm::Intrinsic::ID intId =
      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
          .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
          })
          .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
          })
          .Default([](mlir::Type) {
            llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
            return llvm::Intrinsic::not_intrinsic;
          });

  llvm::Value *packedI16 =
      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
                            llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {packedI16}};
}
NVVM::IDArgPair
ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(Operation &op,
                                             LLVM::ModuleTranslation &mt,
                                             llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);

  llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
  llvm::Value *packedI16 =
      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
                            llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {packedI16}};
}
NVVM::IDArgPair
ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(Operation &op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);

  bool hasRelu = curOp.getRelu();

  llvm::Intrinsic::ID intId =
      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
          .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
          })
          .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
          })
          .Default([](mlir::Type) {
            llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
            return llvm::Intrinsic::not_intrinsic;
          });

  llvm::Value *packedI16 =
      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
                            llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {packedI16}};
}
NVVM::IDArgPair
ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(Operation &op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);

  bool hasRelu = curOp.getRelu();

  llvm::Intrinsic::ID intId =
      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
          .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
          })
          .Default([](mlir::Type) {
            llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
            return llvm::Intrinsic::not_intrinsic;
          });

  llvm::Value *extendedI16 =
      builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
                         llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {extendedI16}};
}
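// tcgen05 allocation/commit lowering: the intrinsic variant is chosen from
// the destination pointer's address space (shared vs. generic) and the CTA
// group attribute.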
NVVM::IDArgPair
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::Shared;
  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }

  // Fill the intrinsic args: destination address and number of columns.
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}
NVVM::IDArgPair
Tcgen05DeallocOp::getIntrinsicIDAndArgs(Operation &op,
                                        LLVM::ModuleTranslation &mt,
                                        llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  // Fill the intrinsic args: tensor-memory address and number of columns.
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )

NVVM::IDArgPair
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::Shared;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the intrinsic args: barrier address, then the optional multicast mask.
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return {id, std::move(args)};
}
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()
llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
  bool hasRelu = getRelu();
  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);

  if (hasRelu && hasSatFinite)
    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
  if (hasRelu)
    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
  if (hasSatFinite)
    return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
  return llvm::Intrinsic::nvvm_ff2f16x2_rs;
}

llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
  bool hasRelu = getRelu();
  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);

  if (hasRelu && hasSatFinite)
    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
  if (hasRelu)
    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
  if (hasSatFinite)
    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
  return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
}
llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
  mlir::Type dstTy = getDstTy();
  bool hasRelu = getRelu();

  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
                       : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
      })
      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
                       : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}

llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
  mlir::Type dstTy = getDstTy();
  bool hasRelu = getRelu();

  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
                       : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
      })
      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
                       : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}

llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
  mlir::Type dstTy = getDstTy();
  bool hasRelu = getRelu();

  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
                       : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}
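// tcgen05.cp lowering: the intrinsic is selected from the copy shape, the
// multicast variant, the source format and the CTA group (see
// GET_TCGEN05_CP_ID above).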
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                    unsigned vecLen) {
  unsigned validLen = llvm::isPowerOf2_32(vecLen) && vecLen <= 128;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
    validLen &= vecLen >= 2;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
    validLen &= vecLen >= 4;
  return validLen;
}

LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("offset argument is required for shape 16x32bx2");

  if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
    result = emitError("offset argument is only supported for shape 16x32bx2");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}

LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("offset argument is required for shape 16x32bx2");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}

/// Verify the range attribute satisfies LLVM ConstantRange constructor
/// requirements for NVVM SpecialRangeableRegisterOp.
static LogicalResult
verifyConstantRangeAttr(Operation *op,
                        std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
  if (!rangeAttr)
    return success();

  const llvm::APInt &lower = rangeAttr->getLower();
  const llvm::APInt &upper = rangeAttr->getUpper();

  // Check if this is an empty or full range.
  if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
    unsigned bitWidth = lower.getBitWidth();
    llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
    llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
    return op->emitOpError(
               "invalid range attribute: Lower == Upper, but they aren't min (")
           << llvm::toString(minVal, 10, false) << ") or max ("
           << llvm::toString(maxVal, 10, false)
           << ") value! This is an invalid constant range.";
  }

  return success();
}
static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}
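// dp4a/dp2a lowering: operands a and b are passed as packed i32 values, and
// the intrinsic is chosen from the (isASigned, isBSigned) bit pair, i.e.
// 0b00 -> u_u, 0b01 -> u_s, 0b10 -> s_u, 0b11 -> s_s.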
NVVM::IDArgPair
DotAccumulate4WayOp::getIntrinsicIDAndArgs(Operation &op,
                                           LLVM::ModuleTranslation &mt,
                                           llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp4a_u_u,
      llvm::Intrinsic::nvvm_idp4a_u_s,
      llvm::Intrinsic::nvvm_idp4a_s_u,
      llvm::Intrinsic::nvvm_idp4a_s_s,
  };
  return {ids[type], args};
}

NVVM::IDArgPair
DotAccumulate2WayOp::getIntrinsicIDAndArgs(Operation &op,
                                           LLVM::ModuleTranslation &mt,
                                           llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(builder.getInt1(curOp.getBHi()));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp2a_u_u,
      llvm::Intrinsic::nvvm_idp2a_u_s,
      llvm::Intrinsic::nvvm_idp2a_s_u,
      llvm::Intrinsic::nvvm_idp2a_s_s,
  };
  return {ids[type], args};
}
static llvm::Value *getParamCastedAddr(llvm::Value *addr,
                                       llvm::IRBuilderBase &builder) {
  return builder.CreateAddrSpaceCast(
      addr,
      llvm::PointerType::get(builder.getContext(),
                             llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
}
NVVM::IDArgPair
PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
                                  LLVM::ModuleTranslation &mt,
                                  llvm::IRBuilderBase &builder) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(op.getAddr()));

  if (op.getTensormap())
    return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};

  assert(cacheLevel && "expected cache level for non-tensormap prefetch");

  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
    return {llvm::Intrinsic::nvvm_prefetchu_L1, args};

  if (evictPriority && *cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
    case NVVM::CacheEvictionPriority::EvictNormal:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  switch (static_cast<MemSpace>(addressSpace)) {
  case MemSpace::Generic:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
               : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
  case MemSpace::Global:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
  case MemSpace::Local:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}
bool NVVM::InlinePtxOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  for (auto arg : getReadWriteArgs())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
  for (auto arg : getResults())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
  for (auto arg : getReadOnlyArgs())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
  return false;
}
NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getSmemAddress()));
  args.push_back(mt.lookupValue(curOp.getMbarrier()));

  llvm::Intrinsic::ID intrinsicID =
      curOp.getMulticast()
          ? llvm::Intrinsic::
                nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
          : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;

  return {intrinsicID, args};
}
NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));

  llvm::Intrinsic::ID intrinsicID;

  switch (curOp.getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    intrinsicID =
        llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
    break;
  }
  return {intrinsicID, args};
}
NVVM::IDArgPair
Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                    llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMAOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  const bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));

  using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
  using CtaGroupArray = std::array<EnableAShiftArray, 2>;
  using IsATensorArray = std::array<CtaGroupArray, 2>;
  using HasScaleInputDArray = std::array<IsATensorArray, 2>;
  using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;

  // Intrinsic table indexed as
  // [hasDisableOutputLane][hasScaleInputD][isATensor][ctaGroup - 1][aShift];
  // entries are listed in that order and rely on aggregate brace elision.
  static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {
      // no disable-output-lane, no scale-d, A in shared memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic,
      llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic,
      // no disable-output-lane, no scale-d, A in tensor memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift,
      // no disable-output-lane, scale-d, A in shared memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic,
      llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic,
      // no disable-output-lane, scale-d, A in tensor memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift,
      // disable-output-lane, no scale-d, A in shared memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1,
      notIntrinsic,
      llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2,
      notIntrinsic,
      // disable-output-lane, no scale-d, A in tensor memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift,
      // disable-output-lane, scale-d, A in shared memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1,
      notIntrinsic,
      llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2,
      notIntrinsic,
      // disable-output-lane, scale-d, A in tensor memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1,
      llvm::Intrinsic::
          nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2,
      llvm::Intrinsic::
          nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift};

  llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
  bool hasScaleInputD = ScaleInputD != nullptr;

  llvm::Value *DisableOutputLane =
      mt.lookupValue(thisOp.getDisableOutputLane());
  bool hasDisableOutputLane = DisableOutputLane != nullptr;

  const unsigned ctaGroup =
      static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));

  llvm::Intrinsic::ID ID =
      tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
                   [ctaGroup - 1][thisOp.getAShift()];

  assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");

  if (hasScaleInputD)
    args.push_back(ScaleInputD);

  if (hasDisableOutputLane)
    args.push_back(DisableOutputLane);

  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));

  if (!hasDisableOutputLane)
    args.push_back(builder.getInt32(ctaGroup));

  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));

  return {ID, std::move(args)};
}
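// Shared verifier for the tcgen05.mma variants: the disable-output-lane
// vector length must match the CTA group (4 for cta_group=1, 8 for
// cta_group=2), and ashift is only valid when matrix A lives in tensor memory.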
static LogicalResult
verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane,
                   NVVM::CTAGroupKind ctaGroup, bool hasAShift,
                   NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) {
  if (disableOutputLane) {
    mlir::VectorType disableOutputLaneType =
        cast<mlir::VectorType>(disableOutputLane.getType());
    if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 &&
         disableOutputLaneType.getNumElements() != 4) ||
        (ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
         disableOutputLaneType.getNumElements() != 8))
      return emitError(loc) << "Disable Output Lane of length "
                            << disableOutputLaneType.getNumElements()
                            << " is incompatible with CtaGroupAttr";
  }

  if (hasAShift && !isATensor)
    return emitError(
        loc, "A-shift can be applied only when matrix A is in tensor memory");

  if (hasAShift && (collectorOp == Tcgen05MMACollectorOp::FILL ||
                    collectorOp == Tcgen05MMACollectorOp::USE))
    return emitError(
        loc, "Cannot use collector buffer operation fill or use with ashift");

  return success();
}

LogicalResult Tcgen05MMAOp::verify() {
  return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
                            getDisableOutputLane(), getCtaGroup(), getAShift(),
                            getCollectorOp(), getLoc());
}
NVVM::IDArgPair
Tcgen05MMASparseOp::getIntrinsicIDAndArgs(Operation &op,
                                          LLVM::ModuleTranslation &mt,
                                          llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
  args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));

  using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
  using CtaGroupArray = std::array<EnableAShiftArray, 2>;
  using IsATensorArray = std::array<CtaGroupArray, 2>;
  using HasScaleInputDArray = std::array<IsATensorArray, 2>;
  using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;

  // Intrinsic table indexed as
  // [hasDisableOutputLane][hasScaleInputD][isATensor][ctaGroup - 1][aShift];
  // entries are listed in that order and rely on aggregate brace elision.
  static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = {
      // no disable-output-lane, no scale-d, A in shared memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic,
      // no disable-output-lane, no scale-d, A in tensor memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift,
      // no disable-output-lane, scale-d, A in shared memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d, notIntrinsic,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d, notIntrinsic,
      // no disable-output-lane, scale-d, A in tensor memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift,
      // disable-output-lane, no scale-d, A in shared memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1,
      notIntrinsic,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2,
      notIntrinsic,
      // disable-output-lane, no scale-d, A in tensor memory (cg1, cg2)
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1,
      llvm::Intrinsic::
          nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift,
      llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2,
      llvm::Intrinsic::
          nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift,
      // disable-output-lane, scale-d, A in shared memory (cg1, cg2)
      llvm::Intrinsic::
          nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1,
      notIntrinsic,
      llvm::Intrinsic::
          nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2,
      notIntrinsic,
      // disable-output-lane, scale-d, A in tensor memory (cg1, cg2)
      llvm::Intrinsic::
          nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1,
      llvm::Intrinsic::
          nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift,
      llvm::Intrinsic::
          nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2,
      llvm::Intrinsic::
          nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift};

  llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD());
  bool hasScaleInputD = ScaleInputD != nullptr;

  llvm::Value *DisableOutputLane =
      mt.lookupValue(thisOp.getDisableOutputLane());
  bool hasDisableOutputLane = DisableOutputLane != nullptr;

  const unsigned ctaGroup =
      static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));

  llvm::Intrinsic::ID ID =
      tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor]
                         [ctaGroup - 1][thisOp.getAShift()];

  assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp.");

  if (hasScaleInputD)
    args.push_back(ScaleInputD);

  if (hasDisableOutputLane)
    args.push_back(DisableOutputLane);

  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));

  if (!hasDisableOutputLane)
    args.push_back(builder.getInt32(ctaGroup));

  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));

  return {ID, std::move(args)};
}

LogicalResult Tcgen05MMASparseOp::verify() {
  return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()),
                            getDisableOutputLane(), getCtaGroup(), getAShift(),
                            getCollectorOp(), getLoc());
}
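// Block-scaled tcgen05.mma lowering: scale_a/scale_b are passed as extra
// operands and the intrinsic is chosen from the kind (mxf8f6f4/mxf4/mxf4nvf4),
// the block-scale attribute and whether A is in tensor memory.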
NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
  args.push_back(mt.lookupValue(thisOp.getScaleA()));
  args.push_back(mt.lookupValue(thisOp.getScaleB()));
  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));

  auto kind = thisOp.getKind();
  auto blockScale = thisOp.getBlockScale();
  llvm::Intrinsic::ID ID = [&]() {
    if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
      }
    } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
        return isATensor
                   ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
                   : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
      }
    } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16;
      }
    }
    llvm_unreachable("Invalid tcgen05.mma.block_scale attributes");
  }();

  return {ID, std::move(args)};
}
static LogicalResult
verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp,
                             NVVM::Tcgen05MMABlockScaleKind kind,
                             NVVM::Tcgen05MMABlockScale blockScale,
                             Location loc) {
  if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
      kind == Tcgen05MMABlockScaleKind::MXF4NVF4)
    return emitError(loc, "mxf4nvf4 requires block scale attribute");

  if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
      kind != Tcgen05MMABlockScaleKind::MXF4NVF4)
    return emitError(loc,
                     llvm::formatv("{} kind does not support block16 attribute",
                                   stringifyEnum(kind)));

  return success();
}

LogicalResult Tcgen05MMABlockScaleOp::verify() {
  return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
                                      getBlockScale(), getLoc());
}
NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
  args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));
  args.push_back(mt.lookupValue(thisOp.getScaleA()));
  args.push_back(mt.lookupValue(thisOp.getScaleB()));
  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));

  auto kind = thisOp.getKind();
  auto blockScale = thisOp.getBlockScale();
  llvm::Intrinsic::ID ID = [&]() {
    if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
      }
    } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf4_block_scale;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
      }
    } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) {
      if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32;
      } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) {
        return isATensor
                   ? llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16
                   : llvm::Intrinsic::
                         nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16;
      }
    }
    llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes");
  }();

  return {ID, std::move(args)};
}

LogicalResult Tcgen05MMASparseBlockScaleOp::verify() {
  return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(),
                                      getBlockScale(), getLoc());
}
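// Weight-stationary tcgen05.mma.ws lowering: an optional zero-column mask
// operand selects the *_zero_col_mask intrinsic variants.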
NVVM::IDArgPair
Tcgen05MMAWsOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));

  mlir::Value ZeroColMask = thisOp.getZeroColMask();
  llvm::Intrinsic::ID ID = notIntrinsic;
  if (ZeroColMask) {
    args.push_back(mt.lookupValue(ZeroColMask));
    ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask
                   : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask;
  } else {
    ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor
                   : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared;
  }

  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));

  return {ID, std::move(args)};
}
NVVM::IDArgPair
Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs(Operation &op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  args.push_back(mt.lookupValue(thisOp.getMatrixD()));

  llvm::Value *A = mt.lookupValue(thisOp.getMatrixA());
  bool isATensor = isa<llvm::PointerType>(A->getType());
  args.push_back(A);

  args.push_back(mt.lookupValue(thisOp.getMatrixB()));
  args.push_back(mt.lookupValue(thisOp.getIdesc()));
  args.push_back(mt.lookupValue(thisOp.getEnableInputD()));
  args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));

  mlir::Value ZeroColMask = thisOp.getZeroColMask();
  llvm::Intrinsic::ID ID = notIntrinsic;
  if (ZeroColMask) {
    args.push_back(mt.lookupValue(ZeroColMask));
    ID = isATensor
             ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask
             : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask;
  } else {
    ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor
                   : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared;
  }

  args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer())));
  args.push_back(
      builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));

  return {ID, std::move(args)};
}
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim exist, they must be integer arrays
  // with at most three elements.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
    }
  }
  // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
  // attributes.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
    }
  }
  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
    if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
        !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be used along with "
             << "'" << NVVMDialect::getReqntidAttrName() << "' and "
             << "'" << NVVMDialect::getClusterDimAttrName() << "'";
    }
  }

  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}
unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
  return static_cast<unsigned>(getValue());
}

bool NVVMMemorySpaceAttr::isValidLoad(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
}

bool NVVMMemorySpaceAttr::isValidStore(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
}

bool NVVMMemorySpaceAttr::isValidAtomicOp(
    ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
    std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAtomicXchg(
    Type type, ptr::AtomicOrdering successOrdering,
    ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
    Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidPtrIntCast(
    Type intLikeTy, Type ptrLikeTy,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files, bool verifyTarget) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();

  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!gpuModuleOp)
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");

  const NVVMCheckSMVersion targetSMVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!targetSMVersion.isMinimumSMVersion())
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");

  WalkResult result = gpuModuleOp->walk([&](Operation *op) {
    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
      const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
      if (!requirement.isCompatibleWith(targetSMVersion)) {
        op->emitOpError() << "is not supported on " << getChip();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return failure(result.wasInterrupted());
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static LogicalResult verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMABlockScaleKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)
static constexpr unsigned notIntrinsic
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class represents a diagnostic that is inflight and set to be reported.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given chracteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
bool isMinimumSMVersion() const
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const
static const NVVMCheckSMVersion getTargetSMVersionFromStr(StringRef smVersionString)
This represents an operation in an abstracted form, suitable for use with the builder APIs.