#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
  auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
  return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
}
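// Illustrative use of the helper above (values chosen for the example): a
// pointer of type !llvm.ptr<3> satisfies the check against
// NVVMMemorySpace::Shared, since shared::cta memory is address space 3.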
                                              size_t numIm2ColOffsets,
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 and 5 dimensions");

  if (numIm2ColOffsets && tensorDims < 3)
    return emitError(
        loc, "to use im2col mode, the tensor has to be at least 3-dimensional");

  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
    return emitError(
        loc, "im2col offsets must be 2 less than number of coordinates");
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  TMAStoreMode mode = getMode();

  if (mode != TMAStoreMode::TILE)
    return emitError("Inline-ptx lowering supported only for Tile mode.");
  if (getL2CacheHint())
    return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");

  switch (mode) {
  case TMAStoreMode::TILE:

  case TMAStoreMode::IM2COL:

  case TMAStoreMode::TILE_SCATTER4:
    if (getCoordinates().size() != 5)
      return emitError("Scatter4 mode expects 5 coordinates");
LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16-byte copies.");
  return success();
}
static LogicalResult verifyTMALoadParams(size_t tensorDims,
                                         size_t numIm2colOff, TMALoadMode mode,
                                         Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 and 5 dimensions");

  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedIm2colOff) -> LogicalResult {
    if (isIm2col && (tensorDims < 3))
      return emitError(loc)
             << "to use " << stringifyEnum(mode)
             << " mode, the tensor has to be at least 3-dimensional";

    if (numIm2colOff != expectedIm2colOff)
      return emitError(loc) << "im2col offsets expected " << expectedIm2colOff
                            << " (provided " << numIm2colOff << ")";

    return success();
  };

  switch (mode) {
  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    return (tensorDims == 5)
               ? checkTMALoadParams(mode, false, 0)
               : emitError(loc, "Gather4 mode expects 5 coordinates");
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  return verifyTMALoadParams(getCoordinates().size(),
                             getIm2colOffsets().size(), getMode(), getLoc());
}
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  TMALoadMode mode = getMode();
  bool isCTAOnly = getIsCTAOnly();
  if (getPredicate()) {
    if (isCTAOnly)
      return emitError("Predicate is supported only for shared::cluster mode.");
    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
      return emitError(
          "Predicate is supported only for Tile and Im2col modes.");
  }

  NVVMMemorySpace expectedAS =
      isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
                    .getAddressSpace();
  if (AS != expectedAS)
    return emitError(
        isCTAOnly ? "Shared::cta destination requires address-space 3."
                  : "Shared::cluster destination requires address-space 7.");

  if (isCTAOnly) {
    if (getMulticastMask())
      return emitError("Multicast is not supported with shared::cta mode.");
    if (getGroup())
      return emitError("CTAGroup is not supported with shared::cta mode.");
  }

  return verifyTMALoadParams(getCoordinates().size(),
                             getIm2colOffsets().size(), getMode(), getLoc());
}
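// Illustrative: with isCTAOnly = true the destination must be a shared::cta
// pointer (address space 3) and neither multicast_mask nor cta_group may be
// present; with isCTAOnly = false the destination must be a shared::cluster
// pointer (address space 7).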
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  TMAStoreMode mode = getMode();
  switch (mode) {
  case TMAStoreMode::TILE:

  case TMAStoreMode::IM2COL:

  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
LogicalResult ConvertFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  if (getRnd() == RndMode::RNA && getRelu())
    return emitError("Relu not supported with rna rounding mode.");
  if (getRnd() != RndMode::RN && getRnd() != RndMode::RZ &&
      getRnd() != RndMode::RNA)
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
  return success();
}
LogicalResult ConvertF32x2ToF6x2Op::verify() {
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
    return emitOpError("Only ")
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f32x2 to f6x2.";
  }
  return success();
}
LogicalResult ConvertF32x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;

  bool hasRelu = getRelu();

  mlir::MLIRContext *ctx = getContext();
  return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
      .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
          [&](mlir::Type) -> LogicalResult {
            if (!isRoundingModeRN) {
              return emitOpError("Only RN rounding mode is supported for "
                                 "conversions from f32x2 to ")
                     << mlir::Float8E4M3FNType::get(ctx) << " and "
                     << mlir::Float8E5M2Type::get(ctx) << " types";
            }
            if (!isSatFinite) {
              return emitOpError("Only SATFINITE saturation mode is supported "
                                 "for conversions from f32x2 to ")
                     << mlir::Float8E4M3FNType::get(ctx) << " and "
                     << mlir::Float8E5M2Type::get(ctx) << " types";
            }
            return success();
          })
      .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
        if (!(isRoundingModeRZ || isRoundingModeRP)) {
          return emitOpError("Only RZ and RP rounding modes are supported for "
                             "conversions from f32x2 to ")
                 << mlir::Float8E8M0FNUType::get(ctx) << " type";
        }
        if (hasRelu) {
          return emitOpError("relu not supported for conversions to ")
                 << mlir::Float8E8M0FNUType::get(ctx) << " type";
        }
        return success();
      })
      .Default([&](mlir::Type) -> LogicalResult {
        return emitOpError("Only ")
               << mlir::Float8E4M3FNType::get(ctx) << ", "
               << mlir::Float8E5M2Type::get(ctx) << ", and "
               << mlir::Float8E8M0FNUType::get(ctx)
               << " types are supported for conversions from f32x2 to f8x2";
      });
}
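// Summary of the legal f32x2 -> f8x2 combinations enforced above:
//   f8E4M3FN / f8E5M2 : rnd = RN and sat = SATFINITE (relu permitted)
//   f8E8M0FNU         : rnd = RZ or RP, and no relu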
LogicalResult ConvertF16x2ToF8x2Op::verify() {
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
    return emitOpError("Only ")
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f16x2 to f8x2.";
  }
  return success();
}
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
    return emitOpError("Only ")
           << mlir::Float8E8M0FNUType::get(ctx)
           << " type is supported for conversions from bf16x2 to f8x2.";

  RndMode rnd = getRnd();
  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");

  return success();
}
LogicalResult ConvertF32x2ToF4x2Op::verify() {
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    return emitOpError("Only ")
           << mlir::Float4E2M1FNType::get(ctx)
           << " type is supported for conversions from f32x2 to f4x2.";

  return success();
}
LogicalResult ConvertF8x2ToF16x2Op::verify() {
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
    return emitOpError("Only ")
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f8x2 to f16x2.";

  return success();
}
LogicalResult ConvertF8x2ToBF16x2Op::verify() {
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
    return emitOpError("Only ")
           << mlir::Float8E8M0FNUType::get(ctx)
           << " type is supported for conversions from f8x2 to bf16x2.";

  return success();
}
LogicalResult ConvertF6x2ToF16x2Op::verify() {
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
    return emitOpError("Only ")
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f6x2 to f16x2.";

  return success();
}
LogicalResult ConvertF4x2ToF16x2Op::verify() {
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
    return emitOpError("Only ")
           << mlir::Float4E2M1FNType::get(ctx)
           << " type is supported for conversions from f4x2 to f16x2.";

  return success();
}
LogicalResult ConvertF32x2ToF16x2Op::verify() {
  if (getRnd() != FPRoundingMode::RS)
    return emitOpError("Only RS rounding mode is supported for "
                       "conversions from f32x2 to f16x2.");
  return success();
}

LogicalResult ConvertF32x2ToBF16x2Op::verify() {
  if (getRnd() != FPRoundingMode::RS)
    return emitOpError("Only RS rounding mode is supported for "
                       "conversions from f32x2 to bf16x2.");
  return success();
}
LogicalResult ConvertF32x4ToF8x4Op::verify() {
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy()))
    return emitOpError("Only ")
           << mlir::Float8E4M3FNType::get(ctx) << " and "
           << mlir::Float8E5M2Type::get(ctx)
           << " types are supported for conversions from f32x4 to f8x4.";

  return success();
}
LogicalResult ConvertF32x4ToF6x4Op::verify() {
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy()))
    return emitOpError("Only ")
           << mlir::Float6E2M3FNType::get(ctx) << " and "
           << mlir::Float6E3M2FNType::get(ctx)
           << " types are supported for conversions from f32x4 to f6x4.";

  return success();
}
LogicalResult ConvertF32x4ToF4x4Op::verify() {
  mlir::MLIRContext *ctx = getContext();

  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
    return emitOpError("Only ")
           << mlir::Float4E2M1FNType::get(ctx)
           << " type is supported for conversions from f32x4 to f4x4.";

  return success();
}
LogicalResult BulkStoreOp::verify() {
  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
  return success();
}
LogicalResult PMEventOp::verify() {
  auto eventId = getEventId();
  auto maskedEventId = getMaskedEventId();
  if (!maskedEventId && !eventId) {
    return emitOpError() << "either `id` or `mask` must be set";
  }

  if (maskedEventId && eventId) {
    return emitOpError() << "`id` and `mask` cannot be set at the same time";
  }

  if (eventId < 0 || eventId > 15) {
    return emitOpError() << "`id` must be between 0 and 15";
  }

  return llvm::success();
}
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      VectorType::get(2, Float16Type::get(operandElType.getContext()));
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}
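// Example of the inference above: an f32 operand infers as tf32 when used as
// a multiplicand but as f32 when used as the accumulator, and a struct
// operand is classified by the type of its first element.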
static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
void MmaOp::print(OpAsmPrinter &p) {
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  SmallVector<Type, 4> regTypes;
  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
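// Usage note: when the caller supplies no layouts, the builder defaults to
// layoutA = row and layoutB = col, the only combination accepted by most
// shapes in MmaOp::verify() below.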
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;
  NamedAttrList namedAttributes;

  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {

  };

  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");

  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());

    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() == 2);
  }

  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(frags[0].regs.size()),
                                    static_cast<int32_t>(frags[1].regs.size()),
                                    static_cast<int32_t>(frags[2].regs.size())}));
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = VectorType::get(2, f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B,
  // C, and the result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  if (resultPtxType() != accumPtxType())
    return emitOpError("requires the result and accumulator types to match");

  // In the M=8 case, the variants are fully enumerated.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, {f64Ty, f64Ty}));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accumulator overflow behavior.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  // Validate layouts; only m8n8k4 with f16 supports non-default layouts.
  bool isM8N8K4WithF16 =
      (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
       getMultiplicandAPtxType() == MMATypes::f16);
  if (!isM8N8K4WithF16) {
    if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
      return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
                         "layoutB = #nvvm.mma_layout<col> for shape <")
             << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
             << "> with element types "
             << stringifyEnum(*getMultiplicandAPtxType()) << " and "
             << stringifyEnum(*getMultiplicandBPtxType())
             << ". Only m8n8k4 with f16 supports other layouts.";
    }
  }

  return success();
}
LogicalResult ShflOp::verify() {
  auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());

  auto verifyTypeError = [&](Twine desc, Type expectedType,
                             Type actualType) -> LogicalResult {
    return emitOpError("expected " + desc + " to be of type ")
           << expectedType << " but got " << actualType << " instead";
  };

  if (returnStructType) {
    if (!getReturnValueAndIsValid())
      return emitOpError("\"return_value_and_is_valid\" attribute must be "
                         "specified when the return type is a struct type");

    if (returnStructType.getBody().size() != 2)
      return emitOpError("expected return type to be a two-element struct");

    auto returnStruct = returnStructType.getBody();
    auto resultType = returnStruct[0];
    if (resultType != getVal().getType())
      return verifyTypeError("first element in the returned struct",
                             getVal().getType(), resultType);

    auto predicateType = returnStruct[1];
    if (!predicateType.isInteger(1))
      return verifyTypeError("second element in the returned struct",
                             IntegerType::get(getContext(), 1), predicateType);

    return success();
  }

  if (getReturnValueAndIsValid())
    return emitOpError("expected return type to be a two-element struct");

  if (getType() != getVal().getType())
    return verifyTypeError("return type", getVal().getType(), getType());

  return success();
}
static std::pair<mlir::Type, unsigned> inferMMAType(NVVM::MMATypes type,
                                                    NVVM::MMAFrag frag,
                                                    int nRow, int nCol,
                                                    MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::f64) {
    elementType = builder.getF64Type();
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 1;
    else
      numberElements = 2;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}
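// Worked example of the two helpers above: for an f16 fragment "a" of an
// m16n16k16 WMMA, inferMMATypeFromMNK picks nRow = 16, nCol = 16, and
// inferMMAType returns 8 elements of type vector<2xf16>.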
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";

  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type f64Ty = Float64Type::get(getContext());
  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
    if (getType() != f64Ty)
      return emitOpError("expected destination type to be f64");
    return success();
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";

  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";

  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
LogicalResult NVVM::LdMatrixOp::verify() {
  uint32_t m = getM(), n = getN(), num = getNum();
  if (m == 8 && n == 8) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
                         "matrix");
    }
    if (getEltType() != LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be b16 for 8x8 matrix");
    }
  } else if (m == 8 && n == 16) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
                         "matrix");
    }
    if (getLayout() != MMALayout::row) {
      return emitOpError("expected layout to be row for 8x16 matrix");
    }
    if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 8x16 matrix");
    }
  } else if (m == 16 && n == 16) {
    if (num != 1 && num != 2) {
      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
                         "matrix");
    }
    if (getLayout() != MMALayout::col) {
      return emitOpError("expected layout to be col for 16x16 matrix");
    }
    if (getEltType() != LdStMatrixEltType::B8 &&
        getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 16x16 matrix");
    }
  } else {
    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
  }

  Type i32 = IntegerType::get(getContext(), 32);
  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
  if (numElements == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (numElements == 2 || numElements == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(numElements, i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << numElements << " elements of type i32";
  }
  return success();
}
LogicalResult NVVM::StMatrixOp::verify() {
  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  int m = getM(), n = getN();
  if (m == 8 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be B16 for 8x8 matrix");
    }
  } else if (m == 16 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B8) {
      return emitOpError("expected element type to be B8 for 16x8 matrix");
    }
    if (getLayout() != NVVM::MMALayout::col) {
      return emitOpError("expected layout to be col for 16x8 matrix");
    }
  } else {
    return emitOpError("expected shape to be 8x8 or 16x8");
  }
  return success();
}
static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}
static LogicalResult isAllowedWGMMADataTypes(NVVM::WGMMATypes typeD,
                                             NVVM::WGMMATypes typeA,
                                             NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  // Check the input/output data-type combination.
  if (failed(isAllowedWGMMADataTypes(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose; only f16 and bf16 inputs support non-default layouts.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check the output size.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }

  // `satfinite` is only valid with an s32 accumulator.
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}
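// Illustrative wgmma configuration accepted by the checks above (values
// chosen for the example): f32 += f16 * f16 with k = 16, n drawn from the
// allowed-N list, and an output struct of n / 2 identical f32 elements.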
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";

  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }
  ss << "},";

  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
        "}\n";
  return ptx;
}
bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}
LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}
LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
  return success();
}
LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 and 15");

  if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
    return emitOpError("reductions are only available when id is 0");

  if (static_cast<bool>(getReductionOp()) !=
      static_cast<bool>(getReductionPredicate()))
    return emitOpError("reduction predicate and reduction operation must be "
                       "specified together");

  return success();
}
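// Usage note: when number_of_threads is present, translation below selects
// the counted intrinsic (nvvm_barrier_cta_sync_aligned_count); the reduction
// forms map to the barrier0 intrinsics and are legal only on barrier id 0.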
LogicalResult NVVM::Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
}
LogicalResult NVVM::MatchSyncOp::verify() {
  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two element struct with "
                         "first element as i32 and second element as i1");
    }
  } else {
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");
    }
  }
  return success();
}
LogicalResult NVVM::VoteSyncOp::verify() {
  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    }
  } else {
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
    }
  }
  return success();
}
LogicalResult NVVM::PrefetchOp::verify() {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();

  if (getTensormap() && cacheLevel)
    return emitOpError("cannot specify both tensormap and cache level");

  if (getTensormap()) {
    if (addressSpace != MemSpace::Generic &&
        addressSpace != MemSpace::Constant) {
      return emitOpError(
          "prefetch tensormap requires a generic or constant pointer");
    }

    if (evictPriority) {
      return emitOpError(
          "prefetch tensormap does not support eviction priority");
    }

    if (getInParamSpace() && addressSpace != MemSpace::Generic) {
      return emitOpError(
          "in_param_space can only be specified for a generic pointer");
    }

  } else if (cacheLevel) {
    if (addressSpace != MemSpace::Generic &&
        addressSpace != MemSpace::Global && addressSpace != MemSpace::Local) {
      return emitOpError("prefetch to cache level requires a generic, global, "
                         "or local pointer");
    }

    if (getUniform()) {
      if (*cacheLevel != CacheLevel::L1) {
        return emitOpError("unsupported cache level, the only supported "
                           "uniform cache level is L1");
      }

      if (addressSpace != MemSpace::Generic) {
        return emitOpError(
            "prefetch to uniform cache requires a generic pointer");
      }
    }

    if (evictPriority) {
      if (*cacheLevel != CacheLevel::L2)
        return emitOpError(
            "cache eviction priority supported only for cache level L2");

      if (addressSpace != MemSpace::Global)
        return emitOpError("cache eviction priority requires a global pointer");

      if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
          *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
        return emitOpError(
            "unsupported cache eviction priority, only evict_last and "
            "evict_normal are supported");
    }
  } else if (getPredicate()) {
    return emitOpError("predicate supported only on prefetch tensormap");
  } else {
    return emitOpError(
        "requires specification of either cache level or tensormap");
  }

  return success();
}
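// Summary of the operand combinations enforced above:
//   tensormap:   generic or constant pointer, no eviction priority
//   cache level: generic/global/local pointer; uniform restricted to L1 with
//                a generic pointer; eviction priority restricted to L2 with a
//                global pointer (evict_normal or evict_last only)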
LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
  switch (getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    if (!getType().isInteger(1))
      return emitOpError("is_canceled query type returns an i1");
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    if (!getType().isInteger(32)) {
      return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
                         "get_first_cta_id_z query types return an i32");
    }
    break;
  }
  return success();
}
LogicalResult NVVM::ReduxOp::verify() {
  Type reduxType = getVal().getType();

  if (!reduxType.isF32()) {
    if (getAbs())
      return emitOpError("abs attribute is supported only for f32 type");
    if (getNan())
      return emitOpError("nan attribute is supported only for f32 type");
  }

  NVVM::ReduxKind kind = getKind();
  switch (kind) {
  case NVVM::ReduxKind::ADD:
  case NVVM::ReduxKind::AND:
  case NVVM::ReduxKind::OR:
  case NVVM::ReduxKind::XOR:
  case NVVM::ReduxKind::MAX:
  case NVVM::ReduxKind::MIN:
  case NVVM::ReduxKind::UMAX:
  case NVVM::ReduxKind::UMIN:
    if (!reduxType.isInteger(32))
      return emitOpError() << "'" << stringifyEnum(kind)
                           << "' redux kind unsupported with " << reduxType
                           << " type. Only supported type is 'i32'.";
    break;
  case NVVM::ReduxKind::FMIN:
  case NVVM::ReduxKind::FMAX:
    if (!reduxType.isF32())
      return emitOpError() << "'" << stringifyEnum(kind)
                           << "' redux kind unsupported with " << reduxType
                           << " type. Only supported type is 'f32'.";
    break;
  }
  return success();
}
static llvm::Value *packValInto64Bits(llvm::IRBuilderBase &builder,
                                      llvm::Value *result, llvm::Value *field,
                                      unsigned sizeInBits, unsigned start) {
  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());

  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
  if (mask != 0xffffffffu)
    field = builder.CreateAnd(field, builder.getInt32(mask));

  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
  field = builder.CreateShl(field, start);

  return builder.CreateOr(result, field);
}
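// Worked example of the packing helper above, as used by createSmemDescriptor
// below: a 14-bit field placed at bit 16 is masked with (1u << 14) - 1,
// zero-extended to i64, shifted left by 16, and OR-ed into the accumulated
// 64-bit descriptor value.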
void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                LLVM::ModuleTranslation &mt,
                                                llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
  llvm::Value *smemDesc = builder.getInt64(0);

  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);

  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);

  mt.mapValue(thisOp.getRes()) = smemDesc;
}
std::string NVVM::MBarrierInitOp::getPtx() {
  bool isShared = isPtrInSharedCTASpace(getAddr());
  return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
                  : std::string("mbarrier.init.b64 [%0], %1;");
}

std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
  bool isShared = isPtrInSharedCTASpace(getAddr());
  return isShared
             ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
             : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
}

std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
  bool isShared = isPtrInSharedCTASpace(getAddr());
  llvm::StringRef space = isShared ? ".shared" : "";

  return llvm::formatv("{\n\t"
                       ".reg .pred P1; \n\t"
                       "LAB_WAIT: \n\t"
                       "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
                       "@P1 bra.uni DONE; \n\t"
                       "bra.uni LAB_WAIT; \n\t"
                       "DONE: \n\t"
                       "}",
                       space);
}
  auto thisOp = cast<NVVM::BarrierOp>(op);
  llvm::Value *barrierId = thisOp.getBarrierId()
                               ? mt.lookupValue(thisOp.getBarrierId())
                               : builder.getInt32(0);
  llvm::Intrinsic::ID id;
  llvm::SmallVector<llvm::Value *> args;
  if (thisOp.getNumberOfThreads()) {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
    args.push_back(barrierId);
    args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
  } else if (thisOp.getReductionOp()) {
    switch (*thisOp.getReductionOp()) {
    case NVVM::BarrierReduction::AND:
      id = llvm::Intrinsic::nvvm_barrier0_and;
      break;
    case NVVM::BarrierReduction::OR:
      id = llvm::Intrinsic::nvvm_barrier0_or;
      break;
    case NVVM::BarrierReduction::POPC:
      id = llvm::Intrinsic::nvvm_barrier0_popc;
      break;
    }
    args.push_back(mt.lookupValue(thisOp.getReductionPredicate()));
  } else {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
    args.push_back(barrierId);
  }

  return {id, std::move(args)};
}
  auto thisOp = cast<NVVM::MBarrierInitOp>(op);
  bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
  llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared
                                    : llvm::Intrinsic::nvvm_mbarrier_init;

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(thisOp.getAddr()));
  args.push_back(mt.lookupValue(thisOp.getCount()));

  return {id, std::move(args)};
}
  auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
  bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
  llvm::Intrinsic::ID id = isShared
                               ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
                               : llvm::Intrinsic::nvvm_mbarrier_inval;
  auto thisOp = cast<NVVM::MBarrierArriveOp>(op);
  bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
  llvm::Intrinsic::ID id = isShared
                               ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared
                               : llvm::Intrinsic::nvvm_mbarrier_arrive;
  auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op);
  bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
  llvm::Intrinsic::ID id =
      isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared
               : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete;

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(thisOp.getAddr()));
  args.push_back(mt.lookupValue(thisOp.getCount()));

  return {id, std::move(args)};
}
  auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
  bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
  llvm::Intrinsic::ID id = isShared
                               ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared
                               : llvm::Intrinsic::nvvm_mbarrier_test_wait;

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(thisOp.getAddr()));
  args.push_back(mt.lookupValue(thisOp.getState()));

  return {id, std::move(args)};
}
  auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
  bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());

  llvm::Intrinsic::ID id;
  if (thisOp.getNoinc()) {
    id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared
                  : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc;
  } else {
    id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared
                  : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive;
  }
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )

  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
             : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
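// Example of the selection above: size = 16 with the CG modifier yields
// llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16 (the _s variant when a
// cp_size operand is present); sizes 4 and 8 always use the CA flavor.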
  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  return {id, std::move(args)};
}
  auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getMbar()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  // Multicast, if available.
  mlir::Value multicastMask = thisOp.getMulticastMask();
  const bool hasMulticastMask = static_cast<bool>(multicastMask);
  llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);

  // Cache hint, if available.
  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);

  // Flag arguments for multicast and cache hint.
  args.push_back(builder.getInt1(hasMulticastMask));
  args.push_back(builder.getInt1(hasCacheHint));

  llvm::Intrinsic::ID id =
      llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
  return {id, std::move(args)};
}
  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  llvm::Intrinsic::ID id =
      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  // Choose the bytemask variant when a byte mask is present.
  if (mlir::Value byteMask = thisOp.getByteMask()) {
    args.push_back(mt.lookupValue(byteMask));
    id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
  }

  return {id, std::move(args)};
}
bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  for (auto val : getOperands())
    asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
NVVM::IDArgPair
CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
  const bool isCTAOnly = thisOp.getIsCTAOnly();
  llvm::SmallVector<llvm::Value *> args;

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getMbar()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (mlir::Value v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));
  for (mlir::Value v : thisOp.getIm2colOffsets())
    args.push_back(mt.lookupValue(v));

  mlir::Value mcMask = thisOp.getMulticastMask();
  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasMC = static_cast<bool>(mcMask);
  llvm::Value *i16Zero =
      llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Zero =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);

  // CTA-group (0 when unspecified, otherwise the enum value plus one).
  int32_t val =
      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
  llvm::Value *ctaGroup =
      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);

  if (!isCTAOnly) {
    // For shared::cluster, all the arguments are applicable.
    args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
    args.push_back(builder.getInt1(hasMC));
    args.push_back(builder.getInt1(hasCacheHint));
    args.push_back(ctaGroup);
  } else {
    // For shared::cta, only the cache hint is applicable.
    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
    args.push_back(builder.getInt1(hasCacheHint));
  }

  constexpr size_t numDims = 5;
  constexpr size_t numModes = 5;
  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
  using TableTy = std::array<rowTy, numModes>;
  static constexpr TableTy IDTable{
      {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};

  static constexpr TableTy IDTableCTA{
      {{notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};

  static_assert(
      (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
          (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
      "TMALoadModes must match number of rows in IDTable and IDTableCTA");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
  assert(id != notIntrinsic &&
         "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");

  return {id, std::move(args)};
}
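// Example of the table lookup above: TILE mode with 3 coordinates selects
// nvvm_cp_async_bulk_tensor_g2s_tile_3d, or the g2s_cta_tile_3d variant when
// is_cta_only is set; entries holding notIntrinsic mark invalid mode/dim
// pairs and are rejected by the assert.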
  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));
  for (auto v : thisOp.getIm2colOffsets())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  const unsigned NI = llvm::Intrinsic::not_intrinsic;
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
      {NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
      {NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
      {NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
      {NI, NI, NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};

  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
                "TMALoadModes must match number of rows in IDTable");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");

  return {id, std::move(args)};
}
NVVM::IDArgPair
CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  const unsigned NI = llvm::Intrinsic::not_intrinsic;
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
      {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
      {NI, NI, NI, NI, NI,
       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};

  static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
                "TMAStoreModes must match number of rows in IDTable");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable(
        "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");

  return {id, std::move(args)};
}
NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
  llvm::LLVMContext &ctx = mt.getLLVMContext();
  llvm::SmallVector<llvm::Value *> args;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (Value v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64ZeroValue =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
  args.push_back(builder.getInt1(hasCacheHint));

  const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;

  constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
  constexpr unsigned numLayouts = 2;  // TILE, IM2COL
  constexpr unsigned maxDim = 5;
  using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
  using layoutTable = std::array<row, numLayouts>;
  using fullTable = std::array<layoutTable, numRedKinds>;
  static constexpr fullTable IDTable{
      {// RedTy::ADD
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
       // RedTy::MIN
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
       // RedTy::MAX
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
       // RedTy::INC
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
       // RedTy::DEC
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
       // RedTy::AND
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
       // RedTy::OR
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
       // RedTy::XOR
       {{{{notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
         {{notIntrinsic, notIntrinsic, notIntrinsic,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
           llvm::Intrinsic::
               nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};

  static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
                "TMAReduxKinds must match number of rows in IDTable");

  size_t redKind = static_cast<size_t>(thisOp.getRedKind());
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();

  assert(redKind < IDTable.size() &&
         "Invalid redKind for CpAsyncBulkTensorReduceOp");
  assert(mode < IDTable[redKind].size() &&
         "Invalid mode for CpAsyncBulkTensorReduceOp");
  assert(dim < IDTable[redKind][mode].size() &&
         "Invalid dim for CpAsyncBulkTensorReduceOp");

  llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
  if (intrinsicID == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorReduceOp.");

  return {intrinsicID, std::move(args)};
}
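// Unlike the flat 2-D C arrays above, this table nests std::array so the
// bounds asserts can query .size() at every level. The indexing shape is
// [redKind][layout][dim]; for example IDTable[ADD][TILE][2] resolves to
// nvvm_cp_async_bulk_tensor_reduce_add_tile_2d.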
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

llvm::Intrinsic::ID
ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, , _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}
NVVM::IDArgPair
ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(op.getA()));
  args.push_back(mt.lookupValue(op.getB()));

  bool hasRelu = op.getRelu();

  llvm::Intrinsic::ID intId =
      hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
              : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;

  return {intId, std::move(args)};
}

#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
        return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
      })
      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
        return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}
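// llvm::TypeSwitch gives these conversions total dispatch over the FP
// destination types: each .Case returns the intrinsic for one element type,
// and .Default traps via llvm_unreachable, so a newly added FP flavor
// without a matching Case fails fast in debug builds instead of silently
// producing not_intrinsic.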
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

llvm::Intrinsic::ID
ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
        return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
      })
      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
        return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
      })
      .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
        if (hasRoundingModeRZ)
          return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
        else if (hasRoundingModeRP)
          return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}

#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
        return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
      })
      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
        return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}

#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}
NVVM::IDArgPair
ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(Operation &op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
  // Assumption: the FP8 flavor is recovered from the element type of the
  // packed source vector.
  mlir::Type elTy =
      llvm::cast<VectorType>(curOp.getSrc().getType()).getElementType();
  bool hasRelu = curOp.getRelu();

  llvm::Intrinsic::ID intId =
      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(elTy)
          .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
          })
          .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
          })
          .Default([](mlir::Type) {
            llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
            return llvm::Intrinsic::not_intrinsic;
          });

  llvm::Value *packedI16 =
      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
                            llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {packedI16}};
}
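// Two 8-bit FP8 values occupy exactly 16 bits, so the packed pair is
// reinterpreted as a single i16 before the call; the conversion intrinsics
// take the pair as one scalar rather than a <2 x i8> vector, and the bitcast
// is free at the hardware level since the bit widths match.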
NVVM::IDArgPair
ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(Operation &op,
                                             LLVM::ModuleTranslation &mt,
                                             llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);

  // Only the ue8m0 flavor converts to bf16x2.
  llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
  llvm::Value *packedI16 =
      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
                            llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {packedI16}};
}

NVVM::IDArgPair
ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(Operation &op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
  // Element type of the packed source selects the intrinsic (as above).
  mlir::Type elTy =
      llvm::cast<VectorType>(curOp.getSrc().getType()).getElementType();
  bool hasRelu = curOp.getRelu();

  llvm::Intrinsic::ID intId =
      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(elTy)
          .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
          })
          .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
          })
          .Default([](mlir::Type) {
            llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
            return llvm::Intrinsic::not_intrinsic;
          });

  llvm::Value *packedI16 =
      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
                            llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {packedI16}};
}

NVVM::IDArgPair
ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(Operation &op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
  mlir::Type elTy =
      llvm::cast<VectorType>(curOp.getSrc().getType()).getElementType();
  bool hasRelu = curOp.getRelu();

  llvm::Intrinsic::ID intId =
      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(elTy)
          .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
            return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
                           : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
          })
          .Default([](mlir::Type) {
            llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
            return llvm::Intrinsic::not_intrinsic;
          });

  llvm::Value *extendedI16 =
      builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
                         llvm::Type::getInt16Ty(builder.getContext()));

  return {intId, {extendedI16}};
}
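// An f4x2 pair is only 8 bits wide, so unlike the 16-bit f8x2/f6x2 packs it
// is zero-extended rather than bitcast to the i16 the intrinsic expects; the
// upper byte carries no information for the conversion.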
NVVM::IDArgPair
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::Shared;
  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }

  // Fill the Intrinsic Args
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}

NVVM::IDArgPair
Tcgen05DeallocOp::getIntrinsicIDAndArgs(Operation &op,
                                        LLVM::ModuleTranslation &mt,
                                        llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  // Fill the Intrinsic Args
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return {id, std::move(args)};
}

#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )

NVVM::IDArgPair
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::Shared;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the Intrinsic Args
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return {id, std::move(args)};
}

#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()
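// GET_TCGEN05_CP_ID wraps its body in an immediately-invoked lambda so that a
// macro used in expression position can contain early returns. As a worked
// expansion: GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA) with srcFmt ==
// Tcgen05CpSrcFormat::B6x16_P32 and is2CTA == true yields
// llvm::Intrinsic::nvvm_tcgen05_cp_128x256b_b6x16_p32_cg2.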
llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
  bool hasRelu = getRelu();
  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);

  if (hasRelu && hasSatFinite)
    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
  if (hasRelu)
    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
  if (hasSatFinite)
    return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
  return llvm::Intrinsic::nvvm_ff2f16x2_rs;
}

llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
  bool hasRelu = getRelu();
  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);

  if (hasRelu && hasSatFinite)
    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
  if (hasRelu)
    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
  if (hasSatFinite)
    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
  return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
}

llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
  // Element type of the packed result selects the intrinsic.
  mlir::Type elTy =
      llvm::cast<VectorType>(getResult().getType()).getElementType();
  bool hasRelu = getRelu();

  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(elTy)
      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite
                       : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite;
      })
      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite
                       : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite;
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}

llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() {
  mlir::Type elTy =
      llvm::cast<VectorType>(getResult().getType()).getElementType();
  bool hasRelu = getRelu();

  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(elTy)
      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite
                       : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite;
      })
      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite
                       : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite;
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}

llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() {
  mlir::Type elTy =
      llvm::cast<VectorType>(getResult().getType()).getElementType();
  bool hasRelu = getRelu();

  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(elTy)
      .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) {
        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite
                       : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite;
      })
      .Default([](mlir::Type) {
        llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op");
        return llvm::Intrinsic::not_intrinsic;
      });
}

llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                    unsigned vecLen) {
  // Lengths must be powers of two; the wider shapes impose a minimum length.
  if (!llvm::isPowerOf2_32(vecLen) || vecLen > 128)
    return false;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
    return vecLen >= 2;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
    return vecLen >= 4;
  return true;
}

LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("offset argument is required for shape 16x32bx2");

  if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
    result = emitError("offset argument is only supported for shape 16x32bx2");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}

LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("offset argument is required for shape 16x32bx2");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}
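// Both verifiers funnel their length checks through isValidVectorLength so
// tcgen05.ld and tcgen05.st stay consistent: the shape determines how many
// elements one access touches (e.g. 16x256b needs at least a 4-element
// vector), and 16x32bx2 additionally requires the immediate offset between
// its two halves.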
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that
/// might have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}

/// Verify the range attribute satisfies LLVM ConstantRange constructor
/// requirements for NVVM SpecialRangeableRegisterOp.
static LogicalResult
verifyConstantRangeAttr(Operation *op,
                        std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
  if (!rangeAttr)
    return success();

  const llvm::APInt &lower = rangeAttr->getLower();
  const llvm::APInt &upper = rangeAttr->getUpper();

  // Check whether this is an invalid range (Lower == Upper, but neither the
  // min nor the max value).
  if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
    unsigned bitWidth = lower.getBitWidth();
    llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
    llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
    return op->emitOpError(
               "invalid range attribute: Lower == Upper, but they aren't min (")
           << llvm::toString(minVal, 10, false) << ") or max ("
           << llvm::toString(maxVal, 10, false)
           << ") value! This is an invalid constant range.";
  }

  return success();
}
static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}

NVVM::IDArgPair
DotAccumulate4WayOp::getIntrinsicIDAndArgs(Operation &op,
                                           LLVM::ModuleTranslation &mt,
                                           llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp4a_u_u,
      llvm::Intrinsic::nvvm_idp4a_u_s,
      llvm::Intrinsic::nvvm_idp4a_s_u,
      llvm::Intrinsic::nvvm_idp4a_s_s,
  };
  return {ids[type], args};
}

NVVM::IDArgPair
DotAccumulate2WayOp::getIntrinsicIDAndArgs(Operation &op,
                                           LLVM::ModuleTranslation &mt,
                                           llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(builder.getInt1(curOp.getBHi()));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp2a_u_u,
      llvm::Intrinsic::nvvm_idp2a_u_s,
      llvm::Intrinsic::nvvm_idp2a_s_u,
      llvm::Intrinsic::nvvm_idp2a_s_s,
  };
  return {ids[type], args};
}

static llvm::Value *getParamCastedAddr(llvm::Value *addr,
                                       llvm::IRBuilderBase &builder) {
  return builder.CreateAddrSpaceCast(
      addr,
      llvm::PointerType::get(builder.getContext(),
                             llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
}
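// The two signedness bits form a 2-bit index into the intrinsic arrays
// above: (isASigned << 1) | isBSigned maps 0b00 to u_u, 0b01 to u_s, 0b10 to
// s_u, and 0b11 to s_s, so the array order must match that encoding exactly.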
NVVM::IDArgPair
PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
                                  LLVM::ModuleTranslation &mt,
                                  llvm::IRBuilderBase &builder) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();

  llvm::SmallVector<llvm::Value *> args;
  llvm::Value *addr = mt.lookupValue(op.getAddr());
  // Assumption: generic pointers flagged as referencing kernel parameters are
  // addrspacecast into the param space before the call.
  args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
                                      : addr);

  if (op.getTensormap())
    return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};

  assert(cacheLevel && "expected cache level for non-tensormap prefetch");

  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
    return {llvm::Intrinsic::nvvm_prefetchu_L1, args};

  if (evictPriority && *cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
    case NVVM::CacheEvictionPriority::EvictNormal:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  switch (static_cast<MemSpace>(addressSpace)) {
  case MemSpace::Generic:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
               : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
  case MemSpace::Global:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
  case MemSpace::Local:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}
bool NVVM::InlinePtxOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  // Emission order mirrors PTX constraint order: read-write ("+") operands,
  // then write-only ("=") results, then read-only operands.
  for (auto arg : getReadWriteArgs())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
  for (auto arg : getResults())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
  for (auto arg : getReadOnlyArgs())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
  return false;
}
NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getSmemAddress()));
  args.push_back(mt.lookupValue(curOp.getMbarrier()));

  llvm::Intrinsic::ID intrinsicID =
      curOp.getMulticast()
          ? llvm::Intrinsic::
                nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
          : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;

  return {intrinsicID, args};
}
NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);
  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));

  llvm::Intrinsic::ID intrinsicID;

  switch (curOp.getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    intrinsicID =
        llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
    break;
  }
  return {intrinsicID, args};
}
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}

LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // The kernel function attribute may only be attached to LLVM functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }

  // If maxntid / reqntid / cluster_dim is present, it must be an integer
  // array with at most three entries.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
    }
  }

  // If minctasm / maxnreg / cluster_max_blocks is present, it must be an
  // integer attribute.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
    }
  }

  // blocks_are_clusters requires both reqntid and cluster_dim.
  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
    if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
        !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be used along with "
             << "'" << NVVMDialect::getReqntidAttrName() << "' and "
             << "'" << NVVMDialect::getClusterDimAttrName() << "'";
    }
  }

  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError()
             << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// NVVM memory space attribute.
//===----------------------------------------------------------------------===//

unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
  return static_cast<unsigned>(getValue());
}

bool NVVMMemorySpaceAttr::isValidLoad(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return LLVM::isValidLoadStoreImpl(type, ordering, alignment, dataLayout,
                                    emitError);
}

bool NVVMMemorySpaceAttr::isValidStore(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return LLVM::isValidLoadStoreImpl(type, ordering, alignment, dataLayout,
                                    emitError);
}

bool NVVMMemorySpaceAttr::isValidAtomicOp(
    ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
    std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: Implement once the `ptr` dialect gains the corresponding atomic op.
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAtomicXchg(
    Type type, ptr::AtomicOrdering successOrdering,
    ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: Implement once the `ptr` dialect gains the corresponding atomic op.
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
    Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: Implement once the `ptr` dialect gains the corresponding cast op.
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidPtrIntCast(
    IntegerType intLikeTy, Type ptrLikeTy,
    function_ref<InFlightDiagnostic()> emitError) const {
  // TODO: Implement once the `ptr` dialect gains the corresponding cast ops.
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files, bool verifyTarget) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();

  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!gpuModuleOp) {
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");
  }

  const NVVMCheckSMVersion targetSMVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!targetSMVersion.isMinimumSMVersion()) {
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");
  }

  WalkResult walkResult = gpuModuleOp->walk([&](Operation *op) {
    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
      const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
      if (!requirement.isCompatibleWith(targetSMVersion)) {
        op->emitOpError() << "is not supported on " << getChip();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isPtrInSharedCTASpace(mlir::Value ptr)
static LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
static llvm::Value * getParamCastedAddr(llvm::Value *addr, llvm::IRBuilderBase &builder)
static llvm::Value * packValInto64Bits(llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
Packs the given field into the result.
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
static llvm::Value * getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
static LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static bool isInt8PtxType(MMATypes type)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
static LogicalResult verifyConstantRangeAttr(Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRang...
static FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape, unsigned vecLen)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
static void nvvmInferResultRanges(Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static constexpr unsigned notIntrinsic
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class represents a diagnostic that is inflight and set to be reported.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, std::optional< int64_t > alignment, const ::mlir::DataLayout *dataLayout, function_ref< InFlightDiagnostic()> emitError)
Checks whether the given type is an LLVM type that can be loaded or stored.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
@ Write
Write register with '=' modifier.
@ ReadWrite
ReadWrite register with '+' modifier.
@ Read
Read register with no modifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given chracteristics.
std::pair< llvm::Intrinsic::ID, llvm::SmallVector< llvm::Value * > > IDArgPair
A pair type of LLVM's Intrinsic ID and args (which are llvm values).
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
bool isMinimumSMVersion() const
bool isCompatibleWith(const NVVMCheckSMVersion &targetSM) const
static const NVVMCheckSMVersion getTargetSMVersionFromStr(StringRef smVersionString)
This represents an operation in an abstracted form, suitable for use with the builder APIs.