#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "nvgpu-to-nvvm"

#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
/// GPU has 32 bit registers, this function truncates values when larger width
/// is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  if (type.getIntOrFloatBitWidth() <= 32)
    return value;
  return LLVM::TruncOp::create(b, b.getI32Type(), value);
}
/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = VectorType::get(2, i32Ty);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = VectorType::get(2, f64Ty);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = VectorType::get(2, f32Ty);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements() * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements() * 2, f32Ty));
  }
  if (a.getElementType() == VectorType::get(1, f32Ty)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f32Ty));
  }
  return vectorResultType;
}
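
// Illustrative example (not part of the original source): a 16x8x16 f16
// `nvgpu.mma.sync` accumulator converts to `!llvm.array<2 x vector<2xf16>>`,
// which the helper above maps to
// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, i.e. the result type of the
// `nvvm.mma.sync` intrinsic. The f64 case always collapses to a two-element
// struct, as the dedicated branch shows.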
/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment with the vector type expected by
/// the `gpu.mma.sync` result.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
  Type i32x2Ty = VectorType::get(2, i32Ty);
  Type f64x2Ty = VectorType::get(2, f64Ty);
  Type f32x2Ty = VectorType::get(2, f32Ty);
  Type f32x1Ty = VectorType::get(1, f32Ty);

  auto makeConst = [&](int32_t index) -> Value {
    return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
                                    rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in the struct; bitcast them
    // back to the element type of the expected array result.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // i32, f64 and f32 values come back as individual scalars; re-pack them
    // pairwise into the 2-element vectors of the result array.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {
      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
        Value x1 =
            LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
        Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
                                                i * 2 + 1);
        vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
                                            x1, makeConst(0));
        vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
                                            x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
                                           el.index());
    }
    return result;
  }

  return intrinsicResult;
}
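
// Illustrative example (not from the original source): for an i32 accumulator
// whose converted result type is `!llvm.array<2 x vector<2xi32>>`, the
// intrinsic yields a struct of four scalar i32 values; the second branch above
// re-packs elements 2*i and 2*i+1 into `vector<2xi32>` values before inserting
// them into the array result.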
/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D vectors whose rows are 32 bits wide, whereas `nvvm.mma.sync`
/// expects a flat list of scalar registers. This helper unpacks the vector
/// arguments and casts them to the types expected by `nvvm.mma.sync`.
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = b.getI32Type();
  Type f64Ty = b.getF64Type();
  Type f32Ty = b.getF32Type();
  Type i64Ty = b.getI64Type();
  Type i8x4Ty = VectorType::get(4, b.getI8Type());
  Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
  Type f32x1Ty = VectorType::get(1, f32Ty);
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = LLVM::ExtractValueOp::create(b, operand, i);

    // 4xi8 / 8xi4 vectors (and tf32 packed as 1xf32) are provided to the
    // intrinsic as i32 scalars.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
      continue;
    }

    // For i32, f64, and f32 element types the inner vector must be unpacked
    // into individual scalars as well.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(LLVM::ExtractElementOp::create(
            b, toUse,
            LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}
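
// Illustrative example (not from the original source): an i8 operand fragment
// arrives as `!llvm.array<N x vector<4xi8>>`; each `vector<4xi8>` element is
// bitcast to a single i32 scalar, matching the 32-bit register operands that
// `nvvm.mma.sync` expects. An f64 fragment instead takes the second path and
// is split into individual scalar elements.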
/// Returns whether the mbarrier object has a shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = {};
  if (isMbarrierShared(barrierType))
    memorySpace =
        IntegerAttr::get(IntegerType::get(context, 64),
                         nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
  return memorySpace;
}

/// Returns the memref type that can be used to represent an mbarrier object.
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
  MemRefLayoutAttrInterface layout;
  return MemRefType::get({barrierType.getNumBarriers()},
                         IntegerType::get(context, 64), layout, memorySpace);
}
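
// Illustrative example (not from the original source): a group of 4 mbarriers
// in shared memory is modeled as `memref<4xi64, 3>`, assuming
// kSharedMemoryAddressSpace is the usual NVVM shared-memory space (3); each
// i64 element backs one hardware mbarrier object.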
/// Lowers `nvgpu.ldmatrix` to `nvvm.ldmatrix`.
struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // The result of ldmatrix is a struct of i32 registers when more than one
    // 32-bit value is returned, and a single i32 otherwise. The NVGPU result
    // is always a vector of shape (NumRegisters, VectorRegister).
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
                                           vectorResultType.getElementType());

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
    Value srcPtr =
        getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
                             adaptor.getSrcMemref(), adaptor.getIndices());
    auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
    Value ldMatrixResult = NVVM::LdMatrixOp::create(
        b, ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row,
        shape, NVVM::LdStMatrixEltType::B16);

    // Unpack the i32 registers, bitcast each back to its 32-bit vector form,
    // and repack them into the converted result type.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = LLVM::PoisonOp::create(b, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
      result = LLVM::InsertValueOp::create(b, result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
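
// Illustrative example (not from the original source): an `nvgpu.ldmatrix`
// producing `vector<4x2xf16>` (four 32-bit registers) lowers to an
// `nvvm.ldmatrix` returning `!llvm.struct<(i32, i32, i32, i32)>`; the loop
// above then bitcasts each extracted i32 back to `vector<2xf16>` and inserts
// it into the converted result.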
/// Maps an element type to the corresponding PTX MMA type.
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}
struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // The shapes of the operand matrices select which intrinsic this op will
    // be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 work only with TF32 enabled.
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Integer matmuls saturate on overflow.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = NVVM::MmaOp::create(
        b, intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/std::nullopt,
        /*intOverflowBehavior=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};
struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
  using Base::Base;

  void runOnOperation() override {
    LowerToLLVMOptions options(&getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext(), options);
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
        });
    // NVGPU token and descriptor types lower to plain integers and pointers.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);

      unsigned numMembers;
      if (elemType.isF32() || elemType.isInteger(32))
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
      else
        llvm_unreachable("unsupported type for warpgroup accumulator");

      SmallVector<Type> innerStructBody;
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

      SmallVector<Type> structBody;
      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
        structBody.push_back(innerStructType);

      auto convertedType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    });
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 64));
    });
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
          return converter.convertType(IntegerType::get(type.getContext(), 64));
        });
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
          nvgpu::getMBarrierMemrefType(type.getContext(), type));
    });
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
/// Returns the constraint string for the sparse MMA inline assembly.
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
    ss << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    ss << "r,";
  // The final operand is for the sparsity metadata; the sparsity selector
  // appears as a direct literal.
  ss << "r";
  return str;
}
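
// Worked example (derived from the helper above, not from the original
// source): for matASize = 4, matBSize = 2, matCSize = 2 the constraint string
// is "=r,=r,r,r,r,r,r,r,r,r,r" -- two outputs, eight register inputs, and a
// final "r" for the sparsity metadata operand.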
/// Returns the asm string for the sparse `mma.sp.sync` instruction.
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  // The operand string is structured into sections `{matC elements...},
  // {matA elements...}, {matB elements...}, {matC elements}`.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
  return asmStr;
}
/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  SmallVector<Value> asmVals;
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return LLVM::InlineAsmOp::create(
      b, intrinsicResultType, asmVals, asmStr, constraintStr,
      /*has_side_effects=*/true, /*is_align_stack=*/false,
      LLVM::TailCallKind::None, asmDialectAttr, /*operand_attrs=*/ArrayAttr());
}
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // The operand matrix shapes select the intrinsic this op is lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Same as the dense MMA lowering: F32 requires TF32 to be enabled.
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};
struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Location loc = op.getLoc();
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dstPtr =
        getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
                             adaptor.getDst(), adaptor.getDstIndices());
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
                             adaptor.getSrcIndices());
    // The intrinsic requires a global pointer for the source.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
    srcPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, srcPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;

    // When the optional srcElements argument is present, only that many
    // elements are read from the source; compute the byte count as
    // (bitwidth * srcElements) >> 3.
    Value srcBytes = adaptor.getSrcElements();
    if (srcBytes) {
      Value c3I32 =
          LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
      Value bitwidth = LLVM::ConstantOp::create(
          b, b.getI32Type(),
          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
      Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
      srcBytes = LLVM::LShrOp::create(
          b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
    }
    // Cache global (.cg) for 16-byte copies that bypass L1, cache all (.ca)
    // otherwise.
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    NVVM::CpAsyncOp::create(
        b, dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes),
        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
        srcBytes);

    // Drop the result token.
    Value zero =
        LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
                                 rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};
struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
    // Drop the result token.
    Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
                                          IntegerType::get(op.getContext(), 32),
                                          rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};
struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservative correct value.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Creates the mbarrier object in shared memory.
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = memref::GlobalOp::create(
        rewriter, funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/ElementsAttr(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
                                                     global.getName());
    return success();
  }
};
/// Base class for lowering mbarrier operations to the NVVM dialect.
template <typename SourceOp>
struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  /// Returns the base pointer of the mbarrier object.
  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
                       Value mbarId,
                       ConversionPatternRewriter &rewriter) const {
    MemRefType mbarrierMemrefType =
        nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
    return ConvertToLLVMPattern::getStridedElementPtr(
        rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
  }
};
struct NVGPUMBarrierGetLowering
    : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Type resType = op.getMbarrierPointer().getType();
    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`.
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Value count = truncToI32(b, adaptor.getCount());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
                                                      adaptor.getPredicate());
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`.
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
/// `nvvm.mbarrier.arrive.nocomplete`.
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    Value count = truncToI32(b, adaptor.getCount());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
        op, tokenType, barrier, count);
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`.
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type retType = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier,
                                                          adaptor.getToken());
    return success();
  }
};
struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value txcount = truncToI32(b, adaptor.getTxcount());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
        op, barrier, txcount, adaptor.getPredicate());
    return success();
  }
};
struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase =
        LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
                                                               phase, ticks);
    return success();
  }
};
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
                                      adaptor.getDst(), {});
    // The intrinsic expects a shared-cluster pointer for the destination.
    auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
        op->getContext(),
        static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
    dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);

    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);

    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        ValueRange{}, adaptor.getMulticastMask(), Value{},
        NVVM::TMALoadMode::TILE, // default is TILE mode
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
                                      adaptor.getSrc(), {});
    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
        op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
        NVVM::TMAStoreMode::TILE, // default is TILE mode
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

    unsigned layout =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
                                                                    : 1;
    unsigned swizzle =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
                                                                    : 0;

    auto ti64 = b.getIntegerType(64);
    auto makeConst = [&](uint64_t index) -> Value {
      return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
    };
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

    Value baseAddr = getStridedElementPtr(
        rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {});
    Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
    // Use only 14 bits for the base address.
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    // [62,64) swizzle type
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    // [49,52) base offset
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    // [32,46) stride
    dsc = insertBit(dsc, strideDim, startStrideBit);
    // [16,30) leading dimension
    dsc = insertBit(dsc, leadDim, startLeadBit);
    // [0,14)  start address
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
           << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
           << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
           << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
           << ")\n start_addr : " << baseAddr;

    rewriter.replaceOp(op, dsc);
    return success();
  }
};
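
// Worked example (derived from the computation above, assuming exclude4LSB is
// 4 as its name suggests; not from the original source): for SWIZZLE_128B and
// a tensor with sizeN = 64, layout = 128 and swizzle = 1, so
// strideDimVal = (128 << 3) >> 4 = 64 and leadDimVal = (64 * 128) >> 4 = 512;
// these land in bits [32,46) and [16,30) of the 64-bit descriptor.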
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
  return LLVM::ConstantOp::create(b, b.getIntegerType(64),
                                  b.getI32IntegerAttr(index));
}
/// Returns a Value holding the data type enum expected by the CUDA driver.
static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
  // Enum values follow the CUDA driver API (CUtensorMapDataType).
  enum CUtensorMapDataTypeEnum {
    CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
    CU_TENSOR_MAP_DATA_TYPE_UINT16,
    CU_TENSOR_MAP_DATA_TYPE_UINT32,
    CU_TENSOR_MAP_DATA_TYPE_INT32,
    CU_TENSOR_MAP_DATA_TYPE_UINT64,
    CU_TENSOR_MAP_DATA_TYPE_INT64,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
    CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
  };

  if (type.isUnsignedInteger(8))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
  if (type.isUnsignedInteger(16))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
  if (type.isUnsignedInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
  if (type.isUnsignedInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
  if (type.isSignlessInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
  if (type.isSignlessInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
  if (type.isF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
  if (type.isF32())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
  if (type.isF64())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
  if (type.isBF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);

  llvm_unreachable("Not supported data type");
}
struct NVGPUTmaCreateDescriptorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
    Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

    Value tensorElementType = elementTypeAsLLVMConstant(
        b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    Value boxArrayPtr = LLVM::AllocaOp::create(
        b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
                                      boxArrayPtr, makeI64Const(b, index));
      LLVM::StoreOp::create(b, value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    // Set the arguments for the runtime tensormap-creation call.
    SmallVector<Value> arguments;
    arguments.push_back(promotedOperands[0]);
    arguments.push_back(promotedOperands[1]);
    arguments.push_back(tensorElementType);
    arguments.push_back(makeI64Const(b, (int)desc.getInterleave()));
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle()));
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo()));
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));
    arguments.push_back(boxArrayPtr);

    // Call the runtime function `mgpuTensorMapEncodeTiledMemref` to build the
    // tensor map on the host.
    SmallVector<Type> argTypes;
    // (element list elided in this listing; it mirrors `arguments` above)
    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

    rewriter.replaceOp(op, tensorMap);
    return success();
  }
};
struct NVGPUWarpgroupMmaOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

  /// Helper class that generates the `nvvm.wgmma.mma_async` ops needed for a
  /// warpgroup-level GEMM of the given shape.
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    ImplicitLocOpBuilder b;
    OpAdaptor adaptor;

    /// Entire GEMM shape, shape of one wgmma instruction, and iteration counts.
    int64_t totalM, totalN, totalK;
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
    int iterationM = 0, iterationN = 0, iterationK = 0;
    /// Finds the wgmma instruction shape for the given GEMM and element type:
    /// M is always 64, N matches the GEMM, and K depends on the element type.
    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      wgmmaM = 64;
      wgmmaN = sizeN;
      if (inputElemType.isTF32()) {
        wgmmaK = 8;
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
        wgmmaK = 16;
      } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
                 inputElemType.isInteger(8)) {
        wgmmaK = 32;
      } else if (inputElemType.isInteger(1)) {
        wgmmaK = 256;
      } else {
        llvm_unreachable("msg: not supported K shape");
      }
      LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
             << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
    }
    /// Generates the WGMMA data type attribute for the given element type.
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
        if (elemType.isF32() || elemType.isTF32())
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
        if (elemType.isBF16())
          return NVVM::WGMMATypes::bf16;
        if (isa<Float8E4M3FNType>(elemType))
          return NVVM::WGMMATypes::e4m3;
        if (isa<Float8E5M2Type>(elemType))
          return NVVM::WGMMATypes::e5m2;
        if (elemType.isInteger(1))
          return NVVM::WGMMATypes::b1;
        if (elemType.isInteger(8))
          return NVVM::WGMMATypes::s8;
        if (elemType.isUnsignedInteger(8))
          return NVVM::WGMMATypes::u8;
        if (elemType.isInteger(32))
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");
      };
      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
    }

    /// Generates the layout attribute: column when transposed, row otherwise.
    NVVM::MMALayoutAttr
    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))
        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
    }

    /// Generates the wgmma shape attribute.
    NVVM::MMAShapeAttr generateWgmmaShape() const {
      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
    }

    /// Generates the scale-out attribute (always `one`).
    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
                                          NVVM::WGMMAScaleOut::one);
    }

    /// Generates the scale-in attribute (always `one`).
    NVVM::WGMMAScaleInAttr generateScaleIn() const {
      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
                                         NVVM::WGMMAScaleIn::one);
    }
    /// Emits an i64 addition; used while iterating the matrix descriptors.
    Value makeAdd(Value lhs, Value rhs) {
      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    };
    /// Advances the descriptor of matrix A so that the next wgmma instruction
    /// reads the correct tile; the increment depends on the iteration indices
    /// and the element width.
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();
      int byte = elemA.getIntOrFloatBitWidth() / 8;
      int tileShapeA = matrixTypeA.getDimSize(1);
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
             << "] [wgmma descriptors] Descriptor A + " << incrementVal;
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// Advances the descriptor of matrix B analogously.
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LDBG() << "Descriptor B + " << incrementVal;
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }
    /// Generates one wgmma instruction for the (i, j, k) iteration.
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
             << "(A[" << (iterationM * wgmmaM) << ":"
             << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
             << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
             << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
             << "][" << 0 << ":" << wgmmaN << "])";

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return NVVM::WgmmaMmaAsyncOp::create(
          b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
          itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }
    /// Generates the group of wgmma instructions that completes the GEMM.
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());

      // Perform GEMM.
      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        Value matrixC =
            LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
                                                  wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }
  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM shape.
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
             << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
             << "] ---===";

      // Find the shape of one wgmma instruction.
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iteration counts needed to cover the GEMM with the wgmma shape.
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
    }
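
    // Worked example (not from the original source): a 128x128x64 f16 GEMM
    // gets the wgmma shape 64x128x16 from findWgmmaShape, so iterationM = 2,
    // iterationN = 1 and iterationK = 4, i.e. eight wgmma ops accumulated into
    // two result fragments by generateWgmmaGroup.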
    /// Generates the full warpgroup MMA: fence, the wgmma group, then the
    /// group sync and wait.
    Value generateWarpgroupMma() {
      NVVM::WgmmaFenceAlignedOp::create(b);
      Value wgmmaResult = generateWgmmaGroup();
      NVVM::WgmmaGroupSyncAlignedOp::create(b);
      NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
      return wgmmaResult;
    }
  };

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build the helper for the given GEMM shape.
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Generate the wgmma instructions.
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace the op with the result.
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};
struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// Stores one fragmented result struct of a warpgroup matrix multiply back
  /// to a memref; the row/column indices are derived from the lane and warp
  /// ids of the executing thread.
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(/*warp size=*/32);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<::mlir::MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = arith::IndexCastOp::create(b, it, x);
      Value idy0 = arith::IndexCastOp::create(b, it, y);
      Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
      Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
      Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
      memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
      memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = NVVM::ThreadIdXOp::create(b, i32);
    Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
    Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
    Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
    Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    if (offset)
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    // Number of 32-bit registers owned per thread.
    constexpr unsigned numAdjacentRegisters = 2;
    // Number of 8x8 matrices stacked one below another per warp.
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }
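
  // Worked example (not from the original source): thread 37 has warpId = 1,
  // laneId = 5, lane4Id = 1 and lane4modId = 1, so ti = 1 + 1 * 16 = 17 and
  // tj = 2; with offset 0 it stores register pairs at rows 17 and 25 and at
  // columns 2, 10, 18, ... as the loops above step idx and idy by 8.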
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matrixDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          LLVM::ExtractValueOp::create(b, matrixDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
    rewriter.eraseOp(op);
    return success();
  }
};
struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType =
        cast<LLVM::LLVMStructType>(packStructType.getBody().front())
            .getBody()
            .front();
    Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
    Value packStruct = LLVM::PoisonOp::create(b, packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the inner structs and set all their values to zero.
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = LLVM::InsertValueOp::create(b, structType, structValue,
                                                  zero, ArrayRef<int64_t>({i}));
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs back into the outer struct.
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
                                               packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};
struct NVGPUTmaFenceOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = op.getContext();
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i32Ty = b.getI32Type();
    Value tensormapSize =
        LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));

    auto memscope =
        NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);

    rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
        op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
    return success();
  }
};
struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
        op, nullptr, nullptr,
        adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
        mlir::UnitAttr::get(op.getContext()));
    return success();
  }
};
/// Lowers `nvgpu.rcp` to `nvvm.rcp.approx.ftz.f32`, element by element.
struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i64Ty = b.getI64Type();
    auto f32Ty = b.getF32Type();
    VectorType inTy = op.getIn().getType();
    // Apply rcp.approx.ftz.f32 to each element of a 1-D vector.
    auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
      Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
      int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
      for (int i = 0; i < numElems; i++) {
        Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
        Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
        Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
        ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
      }
      return ret1DVec;
    };
    if (inTy.getRank() == 1) {
      rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
      return success();
    }
    // Handle higher-rank vectors by converting each 1-D slice.
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
          OpAdaptor adaptor(operands);
          return convert1DVec(llvm1DVectorTy, adaptor.getIn());
        },
        rewriter);
  }
};
void mlir::populateNVGPUToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<NVGPUMBarrierCreateLowering,
               NVGPUMBarrierInitLowering,
               NVGPUMBarrierGetLowering,
               NVGPUMBarrierArriveLowering,
               NVGPUMBarrierArriveNoCompleteLowering,
               NVGPUMBarrierTestWaitLowering,
               NVGPUMBarrierTryWaitParityLowering,
               NVGPUTmaAsyncLoadOpLowering,
               NVGPUTmaAsyncStoreOpLowering,
               NVGPUTmaCreateDescriptorOpLowering,
               NVGPUTmaPrefetchOpLowering,
               NVGPUTmaFenceOpLowering,
               NVGPUMBarrierArriveExpectTxLowering,
               NVGPUGenerateWarpgroupDescriptorLowering,
               NVGPUWarpgroupMmaOpLowering,
               NVGPUWarpgroupMmaStoreOpLowering,
               NVGPUWarpgroupMmaInitAccumulatorOpLowering,
               MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
               NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
}