28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Support/raw_ostream.h"
33 #define DEBUG_TYPE "nvgpu-to-nvvm"
34 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
35 #define DBGSE() (llvm::dbgs())
38 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
39 #include "mlir/Conversion/Passes.h.inc"
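// truncToI32 (below): the GPU has 32-bit registers, so this helper truncates
// values when a larger width is not needed.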
52 assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
62 auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
70 if (a.getElementType() == f16x2Ty) {
74 if (a.getElementType() == i32x2Ty) {
79 if (a.getElementType() == f64x2Ty) {
82 if (a.getElementType() == f32x2Ty) {
91 return vectorResultType;
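// convertIntrinsicResult: converts the SSA result of the nvvm.mma.sync
// intrinsic (always an LLVM struct) into the fragmented vector form that the
// nvgpu op expects.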
103 auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
104 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
114 auto makeConst = [&](int32_t index) -> Value {
124 if (arrayType.getElementType() == f16x2Ty ||
125 arrayType.getElementType() == f32x1Ty) {
126 for (unsigned i = 0; i < structType.getBody().size(); i++) {
128 rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
130 loc, arrayType.getElementType(), el);
131 elements.push_back(el);
139 if (arrayType.getElementType() == i32x2Ty ||
140 arrayType.getElementType() == f64x2Ty ||
141 arrayType.getElementType() == f32x2Ty) {
143 for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
145 rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
147 rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
148 Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
150 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
152 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
154 elements.push_back(vec);
159 Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
161 result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
167 return intrinsicResult;
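// unpackOperandVector: the mma.sync converter expects matrix fragment operands
// as 2D vectors; this unpacks them into the flat list of values taken by the
// intrinsic.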
177 NVVM::MMATypes operandPtxType) {
186 auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
188 for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
189 Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
193 if (arrayTy.getElementType() == i8x4Ty ||
194 arrayTy.getElementType() == i4x8Ty ||
195 (arrayTy.getElementType() == f32x1Ty &&
196 operandPtxType == NVVM::MMATypes::tf32)) {
197 result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
204 VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
205 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
206 innerArrayTy.getElementType() == f64Ty ||
207 innerArrayTy.getElementType() == f32Ty)) {
208 for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
209 idx < innerSize; idx++) {
210 result.push_back(b.create<LLVM::ExtractElementOp>(
216 result.push_back(toUse);
223 return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
224 barrierType.getMemorySpace()));
229 nvgpu::MBarrierGroupType barrierType) {
234 nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
242 nvgpu::MBarrierGroupType barrierType) {
244 MemRefLayoutAttrInterface layout;
255 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
266 auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
267 if (!vectorResultType) {
271 vectorResultType.getElementType(), vectorResultType.getDimSize(1));
273 int64_t num32BitRegs = vectorResultType.getDimSize(0);
275 Type ldMatrixResultType;
276 if (num32BitRegs > 1) {
283 auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
285 getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
286 adaptor.getIndices(), rewriter);
287 Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
288 ldMatrixResultType, srcPtr,
290 op.getTranspose() ? NVVM::MMALayout::col
291 : NVVM::MMALayout::row);
297 Type finalResultType = typeConverter->convertType(vectorResultType);
298 Value result = b.create<LLVM::UndefOp>(finalResultType);
299 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301 num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
303 Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
304 result = b.create<LLVM::InsertValueOp>(result, casted, i);
317 return NVVM::MMATypes::s8;
319 return NVVM::MMATypes::s4;
321 return NVVM::MMATypes::f16;
323 return NVVM::MMATypes::f64;
325 return NVVM::MMATypes::tf32;
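// Lowers nvgpu.mma.sync to the nvvm.mma.sync intrinsic, deducing the PTX
// operand types and packing the A/B/C fragments into the register layout the
// intrinsic expects.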
333 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
338 VectorType aType = op.getMatrixA().getType();
339 VectorType bType = op.getMatrixB().getType();
340 VectorType cType = op.getMatrixC().getType();
342 std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
345 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
346 if (aType.getElementType().isF32() && !tf32Enabled)
351 return op->emitOpError("failed to deduce operand PTX types");
354 return op->emitOpError("failed to deduce operand PTX types");
355 std::optional<NVVM::MMATypes> ptxTypeC =
356 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
360 "could not infer the PTX type for the accumulator/result");
363 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
364 if (isa<IntegerType>(aType.getElementType()))
365 overflow = NVVM::MMAIntOverflow::satfinite;
377 Value intrinsicResult = b.create<NVVM::MmaOp>(
378 intrinsicResTy, matA, matB, matC,
383 std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385 std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
386 NVVM::MMALayout::col});
388 desiredRetTy, intrinsicResult,
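// The pass driver: registers the nvgpu-to-LLVM type conversions (async tokens,
// warpgroup accumulators, mbarriers, tensor map descriptors) and applies the
// partial conversion with the LLVM/arith/memref/NVVM dialects marked legal.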
394 struct ConvertNVGPUToNVVMPass
395 : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
399 registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
400 arith::ArithDialect>();
403 void runOnOperation() override {
411 converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
414 converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
416 int64_t sizeM = type.getFragmented().getDimSize(0);
417 int64_t sizeN = type.getFragmented().getDimSize(1);
421 numMembers = sizeN / 2;
422 else if (elemType.isF16())
423 numMembers = sizeN / 4;
425 llvm_unreachable("unsupported type for warpgroup accumulator");
428 for (unsigned i = 0; i < numMembers; i++)
429 innerStructBody.push_back(elemType);
430 auto innerStructType =
435 structBody.push_back(innerStructType);
439 return converter.convertType(convertedType);
441 converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
444 converter.addConversion(
445 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
448 converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
449 return converter.convertType(
452 converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
457 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
458 target.addLegalDialect<::mlir::arith::ArithDialect>();
459 target.addLegalDialect<::mlir::memref::MemRefDialect>();
460 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
462 converter, patterns, target);
464 std::move(patterns))))
470 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
474 llvm::raw_string_ostream ss(str);
475 for (unsigned i = 0; i < matCSize; i++)
477 for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
490 static std::string buildMmaSparseAsmString(
491 const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
492 unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
493 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
494 std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
495 auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
496 return NVVM::stringifyMMATypes(ptxType);
500 llvm::raw_string_ostream ss(asmStr);
501 ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
502 << shape[2] << ".row.col.";
505 ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
507 ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
508 << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
509 unsigned asmArgIdx = 0;
513 for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
515 for (unsigned i = 0; i < arrSize; i++)
516 ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
519 ss << "$" << asmArgIdx++ << ",";
520 assert(metaDataSelector <= 1);
521 ss << "0x" << metaDataSelector << ";";
530 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
531 std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
533 int64_t metadataSelector, const std::array<int64_t, 3> &shape,
534 Type intrinsicResultType) {
535 auto asmDialectAttr =
538 const unsigned matASize = unpackedAData.size();
539 const unsigned matBSize = unpackedB.size();
540 const unsigned matCSize = unpackedC.size();
542 std::string asmStr = buildMmaSparseAsmString(
543 shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
544 ptxTypeD, overflow, metadataSelector);
545 std::string constraintStr =
546 buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
549 asmVals.reserve(matASize + matBSize + matCSize + 1);
551 llvm::append_range(asmVals, args);
552 asmVals.push_back(indexData);
554 return b.create<LLVM::InlineAsmOp>(
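// Lowers nvgpu.mma.sp.sync to an llvm.inline_asm block carrying the
// mma.sp.sync.aligned PTX string and constraints built above.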
566 struct NVGPUMmaSparseSyncLowering
571 matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
576 VectorType aType = op.getMatrixA().getType();
577 VectorType bType = op.getMatrixB().getType();
578 VectorType cType = op.getMatrixC().getType();
582 return op->emitOpError("failed to deduce operand PTX types");
585 return op->emitOpError("failed to deduce operand PTX types");
586 std::optional<NVVM::MMATypes> ptxTypeC =
587 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
591 "could not infer the PTX type for the accumulator/result");
594 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
595 if (aType.getElementType().isF32() && !tf32Enabled)
599 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
600 if (isa<IntegerType>(aType.getElementType()))
601 overflow = NVVM::MMAIntOverflow::satfinite;
615 Value sparseMetadata = adaptor.getSparseMetadata();
616 if (sparseMetadata.getType() !=
618 return op->emitOpError() << "Expected metadata type to be LLVM "
619 "VectorType of 2 i16 elements";
624 b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
625 matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
627 if (failed(intrinsicResult))
630 assert((*intrinsicResult).getNumResults() == 1 &&
631 "expected inline asm op returns a single LLVM struct type");
634 (*intrinsicResult)->getResult(0), rewriter));
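// Lowers nvgpu.device_async_copy to nvvm.cp.async: computes the strided source
// and destination pointers and picks the CG cache modifier for 16-byte copies
// that bypass L1, CA otherwise.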
639 struct NVGPUAsyncCopyLowering
645 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
649 auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
651 getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
652 adaptor.getDstIndices(), rewriter);
654 getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
655 if (failed(dstAddressSpace))
657 loc, "destination memref address space not convertible to integer");
659 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
661 getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
662 if (failed(srcAddressSpace))
664 loc, "source memref address space not convertible to integer");
666 Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
667 adaptor.getSrcIndices(), rewriter);
671 scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
672 int64_t dstElements = adaptor.getDstElements().getZExtValue();
673 int64_t sizeInBytes =
674 (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
679 Value srcBytes = adaptor.getSrcElements();
691 srcBytes = b.create<LLVM::LShrOp>(
692 b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
696 NVVM::LoadCacheModifierKind cacheModifier =
697 (op.getBypassL1().value_or(false) && sizeInBytes == 16)
698 ? NVVM::LoadCacheModifierKind::CG
699 : NVVM::LoadCacheModifierKind::CA;
701 b.create<NVVM::CpAsyncOp>(
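// Lowers nvgpu.device_async_create_group and nvgpu.device_async_wait to the
// corresponding nvvm.cp.async.commit.group / nvvm.cp.async.wait.group ops.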
714 struct NVGPUAsyncCreateGroupLowering
720 matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
722 rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
732 struct NVGPUAsyncWaitLowering
738 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
741 int32_t numGroups = adaptor.getNumGroups().value_or(0);
742 rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
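// Lowers nvgpu.mbarrier.create by materializing a "__mbarrier" memref.global
// in the shared memory space to back the mbarrier group.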
749 struct NVGPUMBarrierCreateLowering
753 template <typename moduleT>
756 MemRefType barrierType) const {
760 auto global = rewriter.create<memref::GlobalOp>(
761 funcOp->getLoc(), "__mbarrier",
767 symbolTable.insert(global);
772 matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
776 rewriter.getContext(), op.getBarriers().getType());
778 memref::GlobalOp global;
780 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
782 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
792 template <typename SourceOp>
798 nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
801 MemRefType mbarrierMemrefType =
804 b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
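// The mbarrier.* lowerings below share MBarrierBasePattern::getMbarrierPtr to
// compute the address of the selected mbarrier within the group, then rewrite
// each op to its NVVM counterpart.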
809 struct NVGPUMBarrierInitLowering
810 : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
811 using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
814 matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
817 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
819 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
820 adaptor.getMbarId(), rewriter);
824 op, barrier, count, adaptor.getPredicate());
827 adaptor.getPredicate());
834 struct NVGPUMBarrierArriveLowering
835 : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
836 using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
838 matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
842 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
843 adaptor.getMbarId(), rewriter);
844 Type tokenType = getTypeConverter()->convertType(
859 struct NVGPUMBarrierArriveNoCompleteLowering
860 : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
861 using MBarrierBasePattern<
862 nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
864 matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
868 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
869 adaptor.getMbarId(), rewriter);
870 Type tokenType = getTypeConverter()->convertType(
875 op, tokenType, barrier, count);
878 op, tokenType, barrier, count);
885 struct NVGPUMBarrierTestWaitLowering
886 : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
887 using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
889 matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
893 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
894 adaptor.getMbarId(), rewriter);
898 op, retType, barrier, adaptor.getToken());
901 op, retType, barrier, adaptor.getToken());
907 struct NVGPUMBarrierArriveExpectTxLowering
908 : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
909 using MBarrierBasePattern<
910 nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
912 matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
916 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
917 adaptor.getMbarId(), rewriter);
922 op, barrier, txcount, adaptor.getPredicate());
927 op, barrier, txcount, adaptor.getPredicate());
932 struct NVGPUMBarrierTryWaitParityLowering
933 : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
934 using MBarrierBasePattern<
935 nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
937 matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
941 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
942 adaptor.getMbarId(), rewriter);
948 op, barrier, phase, ticks);
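// Lowers nvgpu.tma.async.load: computes the destination pointer and mbarrier
// address, then issues the corresponding NVVM TMA load with the tensor map
// descriptor and coordinates.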
958 struct NVGPUTmaAsyncLoadOpLowering
959 : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
960 using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
962 matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
965 auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
966 Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
967 adaptor.getDst(), {}, rewriter);
969 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
970 adaptor.getMbarId(), rewriter);
977 op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
982 struct NVGPUGenerateWarpgroupDescriptorLowering
988 matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
993 nvgpu::TensorMapSwizzleKind swizzleKind =
994 op.getTensorMap().getType().getSwizzle();
997 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
998 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
999 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1002 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1003 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1004 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1008 auto makeConst = [&](uint64_t index) -> Value {
1011 auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1012 return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1014 auto shiftRight = [&](Value value, unsigned shift) -> Value {
1015 return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1017 auto insertBit = [&](Value desc, Value val, int startBit) {
1018 return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1021 int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1022 uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1023 uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1024 uint64_t offsetVal = 0;
1026 Value strideDim = makeConst(strideDimVal);
1027 Value leadDim = makeConst(leadDimVal);
1029 Value baseAddr = getStridedElementPtr(
1030 op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1031 adaptor.getTensor(), {}, rewriter);
1032 Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1034 Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1036 int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1037 startLeadBit = 16, startBaseAddrBit = 0;
1038 Value dsc = makeConst(0);
1040 dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1042 dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1044 dsc = insertBit(dsc, strideDim, startStrideBit);
1046 dsc = insertBit(dsc, leadDim, startLeadBit);
1048 dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
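// The 64-bit wgmma matrix descriptor is assembled by OR-ing the fields into
// fixed bit positions: swizzle kind at bit 62, base offset at bit 49, stride
// dimension at bit 32, leading dimension at bit 16, and the 14-bit
// shared-memory base address at bit 0.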
1050 LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1051 << "leading_off:" << leadDimVal << "\t"
1052 << "stride_off :" << strideDimVal << "\t"
1053 << "base_offset:" << offsetVal << "\t"
1054 << "layout_type:" << swizzle << " ("
1055 << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1056 << ")\n start_addr : " << baseAddr << "\n");
1072 enum CUtensorMapDataTypeEnum {
1073 CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1074 CU_TENSOR_MAP_DATA_TYPE_UINT16,
1075 CU_TENSOR_MAP_DATA_TYPE_UINT32,
1076 CU_TENSOR_MAP_DATA_TYPE_INT32,
1077 CU_TENSOR_MAP_DATA_TYPE_UINT64,
1078 CU_TENSOR_MAP_DATA_TYPE_INT64,
1079 CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1080 CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1081 CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1082 CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1083 CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1084 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1085 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1089 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1091 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1093 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1095 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1097 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1099 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1101 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1103 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1105 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1107 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1109 llvm_unreachable("Not supported data type");
1112 struct NVGPUTmaCreateDescriptorOpLowering
1117 matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1123 Value tensorElementType =
1124 elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1125 auto promotedOperands = getTypeConverter()->promoteOperands(
1128 Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1129 makeI64Const(b, 5));
1130 for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1131 Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1132 boxArrayPtr, makeI64Const(b, index));
1133 b.create<LLVM::StoreOp>(value, gep);
1136 nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1139 arguments.push_back(promotedOperands[0]);
1140 arguments.push_back(promotedOperands[1]);
1141 arguments.push_back(tensorElementType);
1142 arguments.push_back(
1143 makeI64Const(b, (int)desc.getInterleave()));
1144 arguments.push_back(makeI64Const(b, (int)desc.getSwizzle()));
1145 arguments.push_back(makeI64Const(b, (int)desc.getL2promo()));
1146 arguments.push_back(makeI64Const(b, (int)desc.getOob()));
1147 arguments.push_back(boxArrayPtr);
1161 "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1163 hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
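// Lowers nvgpu.warpgroup.mma: the WarpgroupGemm helper below splits the GEMM
// into wgmma.mma_async tiles, iterating over M/N/K and advancing the A/B
// matrix descriptors for each tile.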
1170 struct NVGPUWarpgroupMmaOpLowering
1194 class WarpgroupGemm {
1195 nvgpu::WarpgroupMmaOp op;
1200 int64_t totalM, totalN, totalK;
1203 int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1206 int iterationM = 0, iterationN = 0, iterationK = 0;
1211 void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1214 if (inputElemType.isTF32()) {
1216 } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1221 } else if (inputElemType.isInteger(1)) {
1224 llvm_unreachable("msg: not supported K shape");
1226 LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1227 << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1231 NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
1232 auto getWgmmaType = [](Type elemType) {
1234 return NVVM::WGMMATypes::tf32;
1235 if (elemType.isF16())
1236 return NVVM::WGMMATypes::f16;
1238 return NVVM::WGMMATypes::bf16;
1240 return NVVM::WGMMATypes::e4m3;
1242 return NVVM::WGMMATypes::e5m2;
1244 return NVVM::WGMMATypes::b1;
1246 return NVVM::WGMMATypes::s8;
1248 return NVVM::WGMMATypes::u8;
1249 llvm_unreachable("unsupported type");
1256 generateWgmmaLayout(std::optional<bool> transpose) const {
1257 if (transpose.value_or(false))
1263 NVVM::MMAShapeAttr generateWgmmaShape() const {
1268 NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1270 NVVM::WGMMAScaleOut::one);
1273 NVVM::WGMMAScaleInAttr generateScaleIn() const {
1275 NVVM::WGMMAScaleIn::one);
1301 Value iterateDescriptorA(Value desc, int i, int j, int k) {
1302 MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1303 Type elemA = matrixTypeA.getElementType();
1305 int tileShapeA = matrixTypeA.getDimSize(1);
1306 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1308 LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1309 << "] [wgmma descriptors] Descriptor A + "
1310 << incrementVal << " | \t ");
1313 return makeAdd(desc, makeI64Const(b, incrementVal));
1327 Value iterateDescriptorB(Value desc, int i, int j, int k) {
1328 MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1329 Type elemB = matrixTypeB.getElementType();
1331 int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1333 LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1336 return makeAdd(desc, makeI64Const(b, incrementVal));
1341 Value generateWgmma(int i, int j, int k, Value matrixC) {
1342 LLVM_DEBUG(DBGS() << "\t wgmma."
1343 << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1344 << "(A[" << (iterationM * wgmmaM) << ":"
1345 << (iterationM * wgmmaM) + wgmmaM << "]["
1346 << (iterationK * wgmmaK) << ":"
1347 << (iterationK * wgmmaK + wgmmaK) << "] * "
1348 << " B[" << (iterationK * wgmmaK) << ":"
1349 << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1350 << wgmmaN << "])\n");
1352 Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1353 Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1355 Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1356 NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1358 Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1359 NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1361 NVVM::MMAShapeAttr shape = generateWgmmaShape();
1362 NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1363 NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1364 NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1365 NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
1368 op->getContext(), NVVM::MMAIntOverflow::wrapped);
1370 return b.create<NVVM::WgmmaMmaAsyncOp>(
1371 matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1372 itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
1376 Value generateWgmmaGroup() {
1378 b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
1382 for (int i = 0; i < iterationM; ++i) {
1383 Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1384 for (int j = 0; j < iterationN; ++j)
1385 for (int k = 0; k < iterationK; ++k)
1386 matrixC = generateWgmma(i, j, k, matrixC);
1387 wgmmaResults.push_back(matrixC);
1390 wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
1391 wgmmaResult, matrix, idx);
1399 : op(op), b(b), adaptor(adaptor) {
1401 totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1402 totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1403 totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1404 LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1405 << "] += A[" << totalM << "][" << totalK << "] * B["
1406 << totalK << "][" << totalN << "] ---===\n");
1411 op.getDescriptorA().getType().getTensor().getElementType());
1414 iterationM = totalM / wgmmaM;
1415 iterationN = totalN / wgmmaN;
1416 iterationK = totalK / wgmmaK;
1424 Value generateWarpgroupMma() {
1425 b.create<NVVM::WgmmaFenceAlignedOp>();
1426 Value wgmmaResult = generateWgmmaGroup();
1427 b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1428 b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
1433 matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1438 WarpgroupGemm warpgroupGemm(op, b, adaptor);
1441 Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
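// Lowers nvgpu.warpgroup.mma.store by unpacking the accumulator structs and
// storing each fragment to the destination memref at per-lane computed
// indices.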
1449 struct NVGPUWarpgroupMmaStoreOpLowering
1495 auto makeConst = [&](int32_t index) -> Value {
1498 Value c1 = makeConst(1);
1499 Value c2 = makeConst(2);
1500 Value c4 = makeConst(4);
1501 Value c8 = makeConst(8);
1502 Value c16 = makeConst(16);
1513 Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1514 Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1515 Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1516 Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1518 auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1523 Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1524 Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1525 Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1530 Value tj = makeMul(lane4modId, c2);
1531 Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1533 ti = makeAdd(ti, makeConst(offset));
1534 for (int i = 0; i < 2; ++i) {
1535 Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1536 for (int j = 0; j < 16; ++j) {
1537 Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1538 int sIndex = i * 2 + j * 4;
1539 makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
1545 matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1549 Value matriDValue = adaptor.getMatrixD();
1553 Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
1554 storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1555 offset += structType.getBody().size();
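// Lowers nvgpu.warpgroup.mma.init.accumulator by building the packed LLVM
// struct-of-structs accumulator value element by element.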
1562 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1567 matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1572 ->convertType(op.getMatrixC().getType())
1580 Value packStruct = b.create<LLVM::UndefOp>(packStructType);
1585 Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1586 for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1587 structValue = b.create<LLVM::InsertValueOp>(
1590 innerStructs.push_back(structValue);
1594 packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1595 packStruct, matrix, idx);
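// Lowers nvgpu.tma.prefetch.descriptor to the NVVM prefetch of the tensor map
// descriptor, forwarding the optional predicate.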
1602 struct NVGPUTmaPrefetchOpLowering
1606 matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1609 op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1619 NVGPUMBarrierCreateLowering,
1620 NVGPUMBarrierInitLowering,
1621 NVGPUMBarrierArriveLowering,
1622 NVGPUMBarrierArriveNoCompleteLowering,
1623 NVGPUMBarrierTestWaitLowering,
1624 NVGPUMBarrierTryWaitParityLowering,
1625 NVGPUTmaAsyncLoadOpLowering,
1626 NVGPUTmaCreateDescriptorOpLowering,
1627 NVGPUTmaPrefetchOpLowering,
1628 NVGPUMBarrierArriveExpectTxLowering,
1629 NVGPUGenerateWarpgroupDescriptorLowering,
1630 NVGPUWarpgroupMmaOpLowering,
1631 NVGPUWarpgroupMmaStoreOpLowering,
1632 NVGPUWarpgroupMmaInitAccumulatorOpLowering,
1633 MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1634 NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1635 NVGPUMmaSparseSyncLowering>(converter);