#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
                                          size_t numIm2ColOffsets,
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimensions");
    return emitError(
        loc, "to use im2col mode, the tensor has to be at least 3-dimensional");
  if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
    return emitError(
        loc, "im2col offsets must be 2 less than number of coordinates");

  TMAStoreMode mode = getMode();
    if (mode != TMAStoreMode::TILE)
      return emitError("Inline-ptx lowering supported only for Tile mode.");
      return emitError("Inline-ptx lowering unsupported with L2 cache-hint.");
  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
  case TMAStoreMode::TILE_SCATTER4:
      return emitError("Scatter4 mode expects 5 coordinates");

  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16 bytes copy.");
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimensions");

  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedIm2colOff) -> LogicalResult {
    if (isIm2col && (tensorDims < 3))
      return emitError(loc)
             << "to use " << stringifyEnum(mode)
             << " mode, the tensor has to be at least 3-dimensional";

    if (numIm2colOff != expectedIm2colOff)
      return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
                            << " (provided " << numIm2colOff << ")";

  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    return (tensorDims == 5)
               ? checkTMALoadParams(mode, false, 0)
               : emitError(loc, "Gather4 mode expects 5 coordinates");
                                       getMode(), getLoc());

  TMALoadMode mode = getMode();
  bool isCTAOnly = getIsCTAOnly();
  if (getPredicate()) {
      return emitError("Predicate is supported only for shared::cluster mode.");
    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
          "Predicate is supported only for Tile and Im2col modes.");

  NVVMMemorySpace expectedAS =
      isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
  unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
  if (AS != expectedAS)
                   ? "Shared::cta destination requires address-space 3."
                   : "Shared::cluster destination requires address-space 7.");

    if (getMulticastMask())
      return emitError("Multicast is not supported with shared::cta mode.");
      return emitError("CTAGroup is not supported with shared::cta mode.");

                                       getMode(), getLoc());

  TMAStoreMode mode = getMode();
  case TMAStoreMode::TILE:
  case TMAStoreMode::IM2COL:
  case TMAStoreMode::TILE_SCATTER4:
    return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp");
  using RndMode = NVVM::FPRoundingMode;
      return emitError("Relu not supported with rna rounding mode.");
        "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");

  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
    return emitOpError("Only ")
           << " types are supported for conversions from f32x2 to f6x2.";

  using RndMode = NVVM::FPRoundingMode;
  using SatMode = NVVM::SaturationMode;

  bool isRoundingModeRN = getRnd() == RndMode::RN;
  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
  bool isRoundingModeRP = getRnd() == RndMode::RP;
  bool isSatFinite = getSat() == SatMode::SATFINITE;

  bool hasRelu = getRelu();

      .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
            if (!isRoundingModeRN) {
              return emitOpError("Only RN rounding mode is supported for "
                                 "conversions from f32x2 to ")
              return emitOpError("Only SATFINITE saturation mode is supported "
      .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
        if (!(isRoundingModeRZ || isRoundingModeRP)) {
          return emitOpError("Only RZ and RP rounding modes are supported for "
                             "conversions from f32x2 to ")
          return emitOpError("relu not supported for conversions to ")
        return emitOpError("Only ")
                           "supported for conversions from f32x2 to f8x2";

  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) {
    return emitOpError("Only ")
           << " types are supported for conversions from f16x2 to f8x2.";

  using RndMode = NVVM::FPRoundingMode;
  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
           << " type is supported for conversions from "
  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");

  if (getInitVal() != 0)
    return emitOpError("only 0 is supported for initVal, got ") << getInitVal();

  auto eventId = getEventId();
  auto maskedEventId = getMaskedEventId();
  if (!maskedEventId && !eventId) {
    return emitOpError() << "either `id` or `mask` must be set";
  if (maskedEventId && eventId) {
    return emitOpError() << "`id` and `mask` cannot be set at the same time";
  if (eventId < 0 || eventId > 15) {
    return emitOpError() << "`id` must be between 0 and 15";

  return llvm::success();
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    return NVVM::MMATypes::s32;
  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
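// A minimal illustrative sketch (not part of the op implementation) of what
// the inference above yields for a few common element types; `ctx` is a
// hypothetical MLIRContext supplied by the caller.
static void inferOperandMMATypeExamples(MLIRContext *ctx) {
  Builder b(ctx);
  // f64 maps to f64 regardless of the operand role.
  assert(MmaOp::inferOperandMMAType(b.getF64Type(), /*isAccumulator=*/false) ==
         NVVM::MMATypes::f64);
  // f32 maps to f32 when used as the accumulator, but to tf32 as a
  // multiplicand, mirroring PTX's tf32 fragments for fp32 inputs.
  assert(MmaOp::inferOperandMMAType(b.getF32Type(), /*isAccumulator=*/true) ==
         NVVM::MMATypes::f32);
  assert(MmaOp::inferOperandMMAType(b.getF32Type(), /*isAccumulator=*/false) ==
         NVVM::MMATypes::tf32);
}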
  return (type == MMATypes::u4 || type == MMATypes::s4);

  return (type == MMATypes::u8 || type == MMATypes::s8);

         type == MMATypes::s32;

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), true);
  assert(val.has_value() && "result PTX type should always be inferrable");

  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), fragIdx >= 2);
      ignoreAttrNames.push_back(frag.ptxTypeAttr);

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;

  for (const auto &frag : frags) {
    printMmaOperand(frag);
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  if (multiplicandPtxTypes) {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))

  if (multiplicandLayouts) {
  if (intOverflow.has_value())
  if (b1Op.has_value())
      MmaOp::getOperandSegmentSizeAttr(),
                                static_cast<int32_t>(operandB.size()),
                                static_cast<int32_t>(operandC.size())}));

  struct OperandFragment {
    std::optional<MMATypes> elemtype;
  std::array<OperandFragment, 4> frags;

  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {

  if (parseMmaOperand("A", frags[0]).failed())
  if (parseMmaOperand("B", frags[1]).failed())
  if (parseMmaOperand("C", frags[2]).failed())

  if (operandTypes.size() != 3)
        "expected one type for each operand segment but got " +
        Twine(operandTypes.size()) + " types");
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
  frags[3].elemtype = inferOperandMMAType(resultType, true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
          "attribute " + names[idx] +
          " is not provided explicitly and cannot be inferred");
    if (!attr.has_value())

  if (!namedAttributes.empty())
                     static_cast<int32_t>(frags[0].regs.size()),
                     static_cast<int32_t>(frags[1].regs.size()),
                     static_cast<int32_t>(frags[2].regs.size()),
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;

  if (mmaShape[0] == 16) {
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));

      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});

    if (resultPtxType() != accumPtxType())
      return emitOpError("ctype does not match dtype");

  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
      allowedShapes.push_back({8, 8, 4});
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      allowedShapes.push_back({8, 8, 32});
      allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    return emitOpError(errorMessage);

  std::array<StringRef, 3> operandNames{"A", "B", "C"};
    auto spec = this->getODSOperandIndexAndLength(iter.index());
                                 operand_type_begin() + spec.first +
    bool match = llvm::is_contained(iter.value(), operandTySeg);

      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);

  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);

  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +

    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +

  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
                                                  NVVM::MMAFrag frag, int nRow,
  unsigned numberElements = 0;
  if (type == NVVM::MMATypes::f16) {
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
  } else if (type == NVVM::MMATypes::f32) {
  } else if (type == NVVM::MMATypes::tf32) {
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
    if (frag == NVVM::MMAFrag::b)
    if (parallelSize == 16)
    else if (parallelSize == 8)
    else if (parallelSize == 32)
  } else if (type == NVVM::MMATypes::s32) {
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);

static std::pair<mlir::Type, unsigned>
  if (frag == NVVM::MMAFrag::a) {
  } else if (frag == NVVM::MMAFrag::b) {
  assert(nRow && nCol);

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected source pointer in memory "

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";

  Type dstType = LLVM::LLVMStructType::getLiteral(
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVMMemorySpace::Global &&
      addressSpace != NVVMMemorySpace::Shared)
    return emitOpError("expected operands to be a source pointer in memory "

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
    return emitOpError() << "invalid attribute combination";

  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
    return emitOpError() << "expected data operands of type " << typeInfo.first;

  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
    return emitOpError() << "invalid attribute combination";

  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "

  Type dstType = LLVM::LLVMStructType::getLiteral(
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  if (m == 8 && n == 8) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
    if (getEltType() != LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be b16 for 8x8 matrix");
  } else if (m == 8 && n == 16) {
    if (num != 1 && num != 2 && num != 4) {
      return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
    if (getLayout() != MMALayout::row) {
      return emitOpError("expected layout to be row for 8x16 matrix");
    if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 8x16 matrix");
  } else if (m == 16 && n == 16) {
    if (num != 1 && num != 2) {
      return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
    if (getLayout() != MMALayout::col) {
      return emitOpError("expected layout to be col for 16x16 matrix");
    if (getEltType() != LdStMatrixEltType::B8 &&
        getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
        getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
      return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
                         "b8x16.b6x16_p32 for 16x16 matrix");
    return emitOpError("expected shape to be 8x8, 8x16 or 16x16");

  uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
  if (numElements == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (numElements == 2 || numElements == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
      return emitOpError("expected destination type is a structure of ")
             << numElements << " elements of type i32";

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  if (m == 8 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B16) {
      return emitOpError("expected element type to be B16 for 8x8 matrix");
  } else if (m == 16 && n == 8) {
    if (getEltType() != NVVM::LdStMatrixEltType::B8) {
      return emitOpError("expected element type to be B8 for 16x8 matrix");
    if (getLayout() != NVVM::MMALayout::col) {
      return emitOpError("expected layout to be col for 16x8 matrix");
    return emitOpError("expected shape to be 8x8 or 16x8");
  if (typeA == NVVM::WGMMATypes::tf32)
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
  if (typeA == NVVM::WGMMATypes::b1)

                                            NVVM::WGMMATypes typeA,
                                            NVVM::WGMMATypes typeB) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");

                                      72,  80,  88,  96,  104, 112, 120, 128,
                                      136, 144, 152, 160, 168, 176, 184, 192,
                                      200, 208, 216, 224, 232, 240, 248, 256};
                                           80,  96,  112, 128, 144, 160,
                                           176, 192, 208, 224, 240, 256};
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";

    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";

    return emitOpError() << "shape 'm' must be 64";

    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";

  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);

  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize

  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
    expectedOutputRegisters = getShape().getN() / 2;

  llvm::raw_string_ostream ss(ptx);
     << ((expectedOutputRegisters * 2) + 2)
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)

  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)

  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";

  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");
  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");
  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");

  if (getNumberOfThreads() && !getBarrierId())
        "barrier id is missing, it should be set between 0 and 15");
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
      return emitError("Invalid multicast type for tcgen05.cp Op");
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");

  if (getKind() == NVVM::MatchSyncKind::all) {
    auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
    if (!type || type.getBody().size() != 2 ||
        !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
      return emitOpError("match.sync 'all' returns a two element struct with "
                         "first element as i32 and second element as i1");
    if (!getType().isInteger(32)) {
      return emitOpError("match.sync 'any' returns an i32");

  if (getKind() == NVVM::VoteSyncKind::ballot) {
    if (!getType().isInteger(32)) {
      return emitOpError("vote.sync 'ballot' returns an i32");
    if (!getType().isInteger(1)) {
      return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
  std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();

  if (getTensormap() && cacheLevel)
    return emitOpError("cannot specify both tensormap and cache level");

  if (getTensormap()) {
    if (addressSpace != MemSpace::Generic &&
        addressSpace != MemSpace::Constant) {
          "prefetch tensormap requires a generic or constant pointer");

    if (evictPriority) {
          "prefetch tensormap does not support eviction priority");

    if (getInParamSpace() && addressSpace != MemSpace::Generic) {
          "in_param_space can only be specified for a generic pointer");

  } else if (cacheLevel) {
    if (addressSpace != MemSpace::Generic && addressSpace != MemSpace::Global &&
        addressSpace != MemSpace::Local) {
      return emitOpError("prefetch to cache level requires a generic, global, "
                         "or local pointer");

      if (*cacheLevel != CacheLevel::L1) {
            "unsupported cache level, the only supported uniform "
            "cache level is L1");

      if (addressSpace != MemSpace::Generic) {
            "prefetch to uniform cache requires a generic pointer");

    if (evictPriority) {
      if (*cacheLevel != CacheLevel::L2)
            "cache eviction priority supported only for cache level L2");

      if (addressSpace != MemSpace::Global)
        return emitOpError("cache eviction priority requires a global pointer");

      if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
          *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
            "unsupported cache eviction priority, only evict_last and "
            "evict_normal are supported");

    return emitOpError("predicate supported only on prefetch tensormap");
        "requires specification of either cache level or tensormap");

  switch (getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
      return emitOpError("is_canceled query type returns an i1");
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    if (!getType().isInteger(32)) {
      return emitOpError("get_first_cta_id_x, get_first_cta_id_y, "
                         "get_first_cta_id_z query types return an i32");
static llvm::Value *
                          llvm::Value *result,
                          unsigned sizeInBits,
  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());

  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
  if (mask != 0xffffffffu)
    field = builder.CreateAnd(field, builder.getInt32(mask));

  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
  field = builder.CreateShl(field, start);

  return builder.CreateOr(result, field);

void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                LLVM::ModuleTranslation &mt,
                                                llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
  llvm::Value *smemDesc = builder.getInt64(0);

                          mt.lookupValue(thisOp.getStartAddr()), 14, 0);
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
                          mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
                          mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);

  mt.mapValue(thisOp.getRes()) = smemDesc;
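// A minimal sketch (assumption, not part of the op definition) of the same
// 64-bit layout computed on plain integers, useful for eyeballing the field
// positions used above: start address in bits [0,14), leading-dim offset in
// [16,30), stride-dim offset in [32,46), base offset in [49,52), leading-dim
// mode at bit 52 and swizzle mode in [61,64).
static uint64_t packSmemDescriptorSketch(uint64_t startAddr,
                                         uint64_t leadingDimOffset,
                                         uint64_t strideDimOffset,
                                         uint64_t baseOffset,
                                         uint64_t leadingDimMode,
                                         uint64_t swizzleMode) {
  auto field = [](uint64_t v, unsigned sizeInBits, unsigned start) {
    uint64_t mask = (sizeInBits < 64) ? ((1ull << sizeInBits) - 1) : ~0ull;
    return (v & mask) << start;
  };
  return field(startAddr, 14, 0) | field(leadingDimOffset, 14, 16) |
         field(strideDimOffset, 14, 32) | field(baseOffset, 3, 49) |
         field(leadingDimMode, 1, 52) | field(swizzleMode, 3, 61);
}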
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )

CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
    llvm_unreachable("Invalid copy size in CpAsyncOp.");

  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
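  // Illustrative expansion (an assumption following the token pasting above):
  // GET_CP_ASYNC_ID(ca, 4, /*has_cpsize=*/true) selects the size-carrying
  // variant llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4_s, while the
  // same call with has_cpsize == false selects
  // llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4.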
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  return {id, std::move(args)};

    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getMbar()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  mlir::Value multicastMask = thisOp.getMulticastMask();
  const bool hasMulticastMask = static_cast<bool>(multicastMask);
  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);

  args.push_back(builder.getInt1(hasMulticastMask));
  args.push_back(builder.getInt1(hasCacheHint));

      llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;

  return {id, std::move(args)};

    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getSize()));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

  if (mlir::Value byteMask = thisOp.getByteMask()) {
    args.push_back(mt.lookupValue(byteMask));
    id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;

  return {id, std::move(args)};
bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
  for (auto val : getOperands())

CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
  const bool isCTAOnly = thisOp.getIsCTAOnly();
  args.push_back(mt.lookupValue(thisOp.getDstMem()));
  args.push_back(mt.lookupValue(thisOp.getMbar()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
    args.push_back(mt.lookupValue(v));
    args.push_back(mt.lookupValue(v));

  const bool hasMC = static_cast<bool>(mcMask);
  llvm::Value *i16Zero =
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Zero =
      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;

    args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
    args.push_back(builder.getInt1(hasMC));
    args.push_back(builder.getInt1(hasCacheHint));
    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
    args.push_back(builder.getInt1(hasCacheHint));

  constexpr size_t numDims = 5;
  constexpr size_t numModes = 5;
  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
  using TableTy = std::array<rowTy, numModes>;
1725 static constexpr TableTy IDTable{
1726 {{
notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
1727 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
1728 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
1729 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
1730 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
1732 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
1733 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
1734 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
1736 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
1737 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
1738 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
1740 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
1741 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
1742 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
1744 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
1746 static constexpr TableTy IDTableCTA{
1748 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
1749 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
1750 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
1751 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
1752 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
1754 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
1755 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
1756 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
1758 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
1759 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
1760 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
1762 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
1763 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
1764 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
1766 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
      (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
          (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
      "TMALoadModes must match number of rows in IDTable and IDTableCTA");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
         "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");

  return {id, std::move(args)};
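// Note on the lookup above: the tables are indexed as IDTable[mode][dim] (or
// IDTableCTA[mode][dim] for the CTA-only form), where `mode` is the TMALoadMode
// enum value and `dim` is the number of coordinates (1-5). Entries holding
// `notIntrinsic` mark mode/dimension combinations that have no corresponding
// NVVM intrinsic, e.g. im2col variants with fewer than 3 dimensions.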
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));
  for (auto v : thisOp.getIm2colOffsets())
    args.push_back(mt.lookupValue(v));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));
1801 const unsigned NI = llvm::Intrinsic::not_intrinsic;
1803 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
1804 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
1805 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
1806 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
1807 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
1809 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
1810 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
1811 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
1813 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
1814 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
1815 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
1817 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
1818 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
1819 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
1820 {NI, NI, NI, NI, NI,
1821 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
                "TMALoadModes must match number of rows in IDTable");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");

  return {id, std::move(args)};
CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op);
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));
1854 const unsigned NI = llvm::Intrinsic::not_intrinsic;
1856 {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d,
1857 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d,
1858 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d,
1859 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d,
1860 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d},
1861 {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d,
1862 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d,
1863 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d},
1864 {NI, NI, NI, NI, NI,
1865 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}};
  static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1,
                "TMAStoreModes must match number of rows in IDTable");
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  if (id == llvm::Intrinsic::not_intrinsic)
        "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp.");

  return {id, std::move(args)};

    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
  llvm::LLVMContext &ctx = mt.getLLVMContext();
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (Value v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64ZeroValue =
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
  args.push_back(builder.getInt1(hasCacheHint));

  constexpr unsigned numRedKinds = 8;
  constexpr unsigned numLayouts = 2;
  constexpr unsigned maxDim = 5;
  using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
  using layoutTable = std::array<row, numLayouts>;
  using fullTable = std::array<layoutTable, numRedKinds>;
1910 static constexpr fullTable IDTable{
1913 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
1914 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
1915 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
1916 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
1917 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
1919 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
1920 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
1921 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
1924 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
1925 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
1926 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
1927 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
1928 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
1930 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
1931 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
1932 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
1935 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
1936 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
1937 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
1938 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
1939 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
1941 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
1942 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
1943 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
1946 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
1947 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
1948 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
1949 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
1950 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
1952 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
1953 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
1954 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
1957 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
1958 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
1959 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
1960 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
1961 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
1963 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
1964 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
1965 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
1968 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
1969 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
1970 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
1971 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
1972 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
1974 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
1975 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
1976 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
1979 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
1980 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
1981 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
1982 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
1983 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
1985 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
1986 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
1987 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
1990 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
1991 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
1992 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
1993 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
1994 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
1996 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
1997 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
1999 nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
  static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
                "TMAReduxKinds must match number of rows in IDTable");

  size_t redKind = static_cast<size_t>(thisOp.getRedKind());
  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();

  assert(redKind < IDTable.size() &&
         "Invalid redKind for CpAsyncBulkTensorReduceOp");
  assert(mode < IDTable[redKind].size() &&
         "Invalid mode for CpAsyncBulkTensorReduceOp");
  assert(dim < IDTable[redKind][mode].size() &&
         "Invalid dim for CpAsyncBulkTensorReduceOp");
         "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");

  return {intrinsicID, std::move(args)};
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");

#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
        return llvm::Intrinsic::not_intrinsic;
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

ConvertF32x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
      .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
        if (hasRoundingModeRZ)
        else if (hasRoundingModeRP)
          llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
        llvm_unreachable("Invalid conversion in ConvertF32x2ToF8x2Op");
        return llvm::Intrinsic::not_intrinsic;
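// Illustrative mapping (an assumption based on the TypeSwitch and macros
// above): an f8E4M3FN destination resolves through GET_F32x2_TO_F8X2_S_ID to
// the rn or rn_relu "ff_to" intrinsic depending on hasRelu, while an
// f8E8M0FNU destination resolves through GET_F32x2_TO_F8X2_US_ID to the
// ue8m0x2 rz/rp variant, optionally with the _satfinite suffix when
// hasSatFinite is set.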
#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
        llvm_unreachable("Invalid conversion in ConvertF16x2ToF8x2Op");
        return llvm::Intrinsic::not_intrinsic;

#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  case NVVM::FPRoundingMode::RZ:
  case NVVM::FPRoundingMode::RP:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
2143 Tcgen05AllocOp::getIntrinsicIDAndArgs(
Operation &op,
2144 LLVM::ModuleTranslation &mt,
2146 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
2147 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
2149 bool isShared = as == NVVMMemorySpace::Shared;
2150 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
2154 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
2155 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
2157 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
2158 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
2162 args.push_back(mt.lookupValue(curOp.getAddr()));
2163 args.push_back(mt.lookupValue(curOp.getNCols()));
2169 Operation &op, LLVM::ModuleTranslation &mt,
2171 auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
2172 auto id = (curOp.getGroup() == CTAGroupKind::CTA_1)
2173 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
2174 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
2177 args.push_back(mt.lookupValue(curOp.getTaddr()));
2178 args.push_back(mt.lookupValue(curOp.getNCols()));
2183 #define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \
2184 is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
2185 : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg
2187 #define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc) \
2188 has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
2189 : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
2192 Tcgen05CommitOp::getIntrinsicIDAndArgs(
Operation &op,
2193 LLVM::ModuleTranslation &mt,
2195 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
2196 unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
2198 bool isShared = as == NVVMMemorySpace::Shared;
2199 bool hasMulticast =
static_cast<bool>(curOp.getMulticastMask());
2200 bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2;
2207 args.push_back(mt.lookupValue(curOp.getAddr()));
2209 args.push_back(mt.lookupValue(curOp.getMulticastMask()));
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()

llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                    unsigned vecLen) {
  // A valid vector length is a power of two; the wider shapes additionally
  // require a minimum length.
  unsigned validLen = llvm::isPowerOf2_32(vecLen) && vecLen <= 128;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
    return validLen && (vecLen >= 2);
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
    return validLen && (vecLen >= 4);
  return validLen;
}

LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}
LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}
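// dp4a/dp2a lowering packs the narrow-integer vector operands into i32 values
// before invoking the intrinsic.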
static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}
NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;

  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp4a_u_u,
      llvm::Intrinsic::nvvm_idp4a_u_s,
      llvm::Intrinsic::nvvm_idp4a_s_u,
      llvm::Intrinsic::nvvm_idp4a_s_s,
  };
  return {ids[type], args};
}
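// The signedness of the two operands indexes the intrinsic table:
// (isASigned << 1) | isBSigned selects u_u, u_s, s_u, or s_s.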
NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(builder.getInt1(curOp.getBHi()));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;

  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp2a_u_u,
      llvm::Intrinsic::nvvm_idp2a_u_s,
      llvm::Intrinsic::nvvm_idp2a_s_u,
      llvm::Intrinsic::nvvm_idp2a_s_s,
  };
  return {ids[type], args};
}
static llvm::Value *getParamCastedAddr(llvm::Value *addr,
                                       llvm::IRBuilderBase &builder) {
  return builder.CreateAddrSpaceCast(
      addr,
      llvm::PointerType::get(builder.getContext(),
                             llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
}
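// Prefetch lowering: tensormap prefetches map to a single intrinsic, uniform
// L1 prefetches use prefetchu, and the remaining cases are keyed on cache
// level, eviction priority, and the pointer's address space.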
NVVM::IDArgPair PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
                                                  LLVM::ModuleTranslation &mt,
                                                  llvm::IRBuilderBase &builder) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();

  llvm::SmallVector<llvm::Value *> args;
  llvm::Value *addr = mt.lookupValue(op.getAddr());
  // Tensormap prefetches of generic pointers go through a cast into the param
  // address space.
  args.push_back((op.getTensormap() &&
                  static_cast<MemSpace>(addressSpace) == MemSpace::Generic)
                     ? getParamCastedAddr(addr, builder)
                     : addr);

  if (op.getTensormap())
    return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};

  assert(cacheLevel && "expected cache level for non-tensormap prefetch");

  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
    return {llvm::Intrinsic::nvvm_prefetchu_L1, args};

  if (evictPriority && *cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
    case NVVM::CacheEvictionPriority::EvictNormal:
      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  switch (static_cast<MemSpace>(addressSpace)) {
  case MemSpace::Generic:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
               : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
  case MemSpace::Global:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
  case MemSpace::Local:
    return *cacheLevel == CacheLevel::L1
               ? NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
               : NVVM::IDArgPair(
                     {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}
bool NVVM::InlinePtxOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  for (auto arg : getReadWriteArgs())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
  for (auto arg : getResults())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
  for (auto arg : getReadOnlyArgs())
    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
  return false;
}
NVVM::IDArgPair ClusterLaunchControlTryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlTryCancelOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getSmemAddress()));
  args.push_back(mt.lookupValue(curOp.getMbarrier()));

  llvm::Intrinsic::ID intrinsicID =
      curOp.getMulticast()
          ? llvm::Intrinsic::
                nvvm_clusterlaunchcontrol_try_cancel_async_multicast_shared
          : llvm::Intrinsic::nvvm_clusterlaunchcontrol_try_cancel_async_shared;

  return {intrinsicID, args};
}
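// clusterlaunchcontrol.query.cancel: each query type maps one-to-one onto an
// intrinsic that inspects the try_cancel response.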
NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::ClusterLaunchControlQueryCancelOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(curOp.getTryCancelResponse()));

  llvm::Intrinsic::ID intrinsicID;
  switch (curOp.getQueryType()) {
  case NVVM::ClusterLaunchControlQueryType::IS_CANCELED:
    intrinsicID =
        llvm::Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_X:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_x;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Y:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y;
    break;
  case NVVM::ClusterLaunchControlQueryType::GET_FIRST_CTA_ID_Z:
    intrinsicID = llvm::Intrinsic::
        nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z;
    break;
  }
  return {intrinsicID, args};
}
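// Dialect initialization: register the tablegen'd operations and attributes
// and declare the promised ConvertToLLVM and GPU target interfaces.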
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
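// Verify NVVM dialect attributes (kernel, maxntid/reqntid/cluster_dim,
// minctasm/maxnreg/cluster_max_blocks, blocksareclusters) wherever they are
// attached.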
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // The kernel attribute may only be attached to LLVM functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // maxntid, reqntid, and cluster_dim must be integer arrays with at most
  // three entries.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
    }
  }
  // minctasm, maxnreg, and cluster_max_blocks must be integer constants.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
    }
  }
  // blocksareclusters requires both reqntid and cluster_dim to be present.
  if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) {
    if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) ||
        !op->hasAttr(NVVMDialect::getClusterDimAttrName())) {
      return op->emitError()
             << "'" << attrName << "' attribute must be used along with "
             << "'" << NVVMDialect::getReqntidAttrName() << "' and "
             << "'" << NVVMDialect::getClusterDimAttrName() << "'";
    }
  }

  return success();
}
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }
  return success();
}
unsigned NVVMMemorySpaceAttr::getAddressSpace() const {
  return static_cast<unsigned>(getValue());
}

bool NVVMMemorySpaceAttr::isValidLoad(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
}

bool NVVMMemorySpaceAttr::isValidStore(
    Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError);
}

// TODO: Implement the atomic and cast checks below once the corresponding
// `ptr` dialect operations are supported.
bool NVVMMemorySpaceAttr::isValidAtomicOp(
    ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering,
    std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAtomicXchg(
    Type type, ptr::AtomicOrdering successOrdering,
    ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment,
    const ::mlir::DataLayout *dataLayout,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidAddrSpaceCast(
    Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}

bool NVVMMemorySpaceAttr::isValidPtrIntCast(
    Type intLikeTy, Type ptrLikeTy,
    function_ref<InFlightDiagnostic()> emitError) const {
  assert(false && "unimplemented, see TODO in the source.");
  return false;
}
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files, bool verifyTarget) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}
LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();

  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!gpuModuleOp) {
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");
  }

  const NVVMCheckSMVersion targetSMVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!targetSMVersion.isMinimumSMVersion()) {
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");
  }

  WalkResult walkResult = gpuModuleOp->walk([&](Operation *op) {
    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
      const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
      if (!requirement.isCompatibleWith(targetSMVersion)) {
        op->emitOpError() << "is not supported on " << getChip();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"