28#include "llvm/Support/Debug.h"
29#include "llvm/Support/DebugLog.h"
30#include "llvm/Support/ErrorHandling.h"
31#include "llvm/Support/raw_ostream.h"
34#define DEBUG_TYPE "nvgpu-to-nvvm"
37#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
38#include "mlir/Conversion/Passes.h.inc"
51 assert(llvm::isa<IntegerType>(type) &&
"expected an integer Value");
54 return LLVM::TruncOp::create(
b,
b.getI32Type(), value);
61 auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
62 auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
63 auto i32Ty = IntegerType::get(ctx, 32);
64 auto i32x2Ty = VectorType::get(2, i32Ty);
65 Type f64Ty = Float64Type::get(ctx);
66 Type f64x2Ty = VectorType::get(2, f64Ty);
67 Type f32Ty = Float32Type::get(ctx);
68 Type f32x2Ty = VectorType::get(2, f32Ty);
69 if (a.getElementType() == f16x2Ty) {
70 return LLVM::LLVMStructType::getLiteral(
73 if (a.getElementType() == i32x2Ty) {
74 return LLVM::LLVMStructType::getLiteral(
78 if (a.getElementType() == f64x2Ty) {
79 return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
81 if (a.getElementType() == f32x2Ty) {
82 return LLVM::LLVMStructType::getLiteral(
86 if (a.getElementType() == VectorType::get(1, f32Ty)) {
87 return LLVM::LLVMStructType::getLiteral(
90 return vectorResultType;
102 auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
103 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
108 Type i32x2Ty = VectorType::get(2, i32Ty);
109 Type f64x2Ty = VectorType::get(2, f64Ty);
110 Type f32x2Ty = VectorType::get(2, f32Ty);
111 Type f32x1Ty = VectorType::get(1, f32Ty);
113 auto makeConst = [&](int32_t
index) ->
Value {
114 return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
123 if (arrayType.getElementType() == f16x2Ty ||
124 arrayType.getElementType() == f32x1Ty) {
125 for (
unsigned i = 0; i < structType.getBody().size(); i++) {
127 LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
129 loc, arrayType.getElementType(), el);
130 elements.push_back(el);
138 if (arrayType.getElementType() == i32x2Ty ||
139 arrayType.getElementType() == f64x2Ty ||
140 arrayType.getElementType() == f32x2Ty) {
142 for (
unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
144 LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
146 LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
147 Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
149 vec = LLVM::InsertElementOp::create(rewriter, loc, vec.
getType(), vec,
151 vec = LLVM::InsertElementOp::create(rewriter, loc, vec.
getType(), vec,
153 elements.push_back(vec);
158 Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
159 for (
const auto &el : llvm::enumerate(elements)) {
160 result = LLVM::InsertValueOp::create(rewriter, loc,
result, el.value(),
166 return intrinsicResult;
176 NVVM::MMATypes operandPtxType) {
178 Type i32Ty =
b.getI32Type();
179 Type f64Ty =
b.getF64Type();
180 Type f32Ty =
b.getF32Type();
181 Type i64Ty =
b.getI64Type();
182 Type bf16x2Ty = VectorType::get(2,
b.getBF16Type());
183 Type i8x4Ty = VectorType::get(4,
b.getI8Type());
184 Type i4x8Ty = VectorType::get(8,
b.getIntegerType(4));
185 Type f32x1Ty = VectorType::get(1, f32Ty);
186 auto arrayTy = cast<LLVM::LLVMArrayType>(operand.
getType());
188 for (
unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
189 Value toUse = LLVM::ExtractValueOp::create(
b, operand, i);
193 if (arrayTy.getElementType() == i8x4Ty ||
194 arrayTy.getElementType() == i4x8Ty ||
195 (arrayTy.getElementType() == bf16x2Ty &&
196 operandPtxType == NVVM::MMATypes::bf16) ||
197 (arrayTy.getElementType() == f32x1Ty &&
198 operandPtxType == NVVM::MMATypes::tf32)) {
199 result.push_back(LLVM::BitcastOp::create(
b, i32Ty, toUse));
206 VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
207 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
208 innerArrayTy.getElementType() == f64Ty ||
209 innerArrayTy.getElementType() == f32Ty)) {
210 for (
unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
211 idx < innerSize; idx++) {
212 result.push_back(LLVM::ExtractElementOp::create(
214 LLVM::ConstantOp::create(
b, i64Ty,
b.getI64IntegerAttr(idx))));
225 return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
226 barrierType.getMemorySpace()));
231 nvgpu::MBarrierGroupType barrierType) {
235 IntegerAttr::get(IntegerType::get(context, 64),
236 nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
244 nvgpu::MBarrierGroupType barrierType) {
246 MemRefLayoutAttrInterface layout;
247 return MemRefType::get({barrierType.getNumBarriers()},
248 IntegerType::get(context, 64), layout, memorySpace);
254 using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
257 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
258 ConversionPatternRewriter &rewriter)
const override {
260 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
268 auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
269 if (!vectorResultType) {
272 Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
273 vectorResultType.getElementType());
275 int64_t num32BitRegs = vectorResultType.getDimSize(0);
277 Type ldMatrixResultType;
278 if (num32BitRegs > 1) {
279 ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
280 ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
282 ldMatrixResultType = rewriter.getI32Type();
285 auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
288 adaptor.getSrcMemref(), adaptor.getIndices());
289 auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
290 Value ldMatrixResult = NVVM::LdMatrixOp::create(
291 b, ldMatrixResultType, srcPtr,
293 op.getTranspose() ? NVVM::MMALayout::col
294 : NVVM::MMALayout::row,
295 shape, NVVM::LdStMatrixEltType::B16);
301 Type finalResultType = typeConverter->convertType(vectorResultType);
302 Value
result = LLVM::PoisonOp::create(
b, finalResultType);
303 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
305 num32BitRegs > 1 ? LLVM::ExtractValueOp::create(
b, ldMatrixResult, i)
307 Value casted = LLVM::BitcastOp::create(
b, innerVectorType, i32Register);
311 rewriter.replaceOp(op,
result);
318static FailureOr<NVVM::MMATypes> getNvvmMmaType(
Type t) {
321 return NVVM::MMATypes::s8;
323 return NVVM::MMATypes::s4;
325 return NVVM::MMATypes::f16;
327 return NVVM::MMATypes::bf16;
329 return NVVM::MMATypes::f64;
331 return NVVM::MMATypes::tf32;
336 using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
339 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter)
const override {
341 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
344 VectorType aType = op.getMatrixA().getType();
345 VectorType bType = op.getMatrixA().getType();
346 VectorType cType = op.getMatrixC().getType();
348 std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
351 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
352 if (aType.getElementType().isF32() && !tf32Enabled)
355 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
357 return op->emitOpError(
"failed to deduce operand PTX types");
358 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
360 return op->emitOpError(
"failed to deduce operand PTX types");
361 std::optional<NVVM::MMATypes> ptxTypeC =
362 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
365 return op->emitError(
366 "could not infer the PTX type for the accumulator/result");
369 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
370 if (isa<IntegerType>(aType.getElementType()))
371 overflow = NVVM::MMAIntOverflow::satfinite;
373 SmallVector<Value> matA =
375 SmallVector<Value> matB =
377 SmallVector<Value> matC =
380 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
382 typeConverter->convertType(op->getResultTypes()[0]));
383 Value intrinsicResult =
384 NVVM::MmaOp::create(
b, intrinsicResTy, matA, matB, matC,
389 std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
391 std::array<NVVM::MMALayout, 2>{
392 NVVM::MMALayout::row, NVVM::MMALayout::col});
394 desiredRetTy, intrinsicResult,
400struct ConvertNVGPUToNVVMPass
401 :
public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
404 void runOnOperation()
override {
414 converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
415 return converter.convertType(IntegerType::get(type.getContext(), 32));
417 converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
418 Type elemType = type.getFragmented().getElementType();
419 int64_t sizeM = type.getFragmented().getDimSize(0);
420 int64_t sizeN = type.getFragmented().getDimSize(1);
424 numMembers = sizeN / 2;
425 else if (elemType.
isF16())
426 numMembers = sizeN / 4;
428 llvm_unreachable(
"unsupported type for warpgroup accumulator");
430 SmallVector<Type> innerStructBody;
431 for (
unsigned i = 0; i < numMembers; i++)
432 innerStructBody.push_back(elemType);
433 auto innerStructType =
434 LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
436 SmallVector<Type> structBody;
438 structBody.push_back(innerStructType);
441 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
442 return converter.convertType(convertedType);
444 converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
445 return converter.convertType(IntegerType::get(type.getContext(), 64));
447 converter.addConversion(
448 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
449 return converter.convertType(IntegerType::get(type.getContext(), 64));
451 converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
452 return converter.convertType(
455 converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
456 return LLVM::LLVMPointerType::get(type.getContext());
460 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
461 target.addLegalDialect<::mlir::arith::ArithDialect>();
462 target.addLegalDialect<::mlir::memref::MemRefDialect>();
463 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
465 converter, patterns,
target);
466 if (
failed(applyPartialConversion(getOperation(),
target,
467 std::move(patterns))))
473static std::string buildMmaSparseAsmConstraintString(
unsigned matASize,
477 llvm::raw_string_ostream ss(str);
478 for (
unsigned i = 0; i < matCSize; i++)
480 for (
unsigned i = 0; i < matASize + matBSize + matCSize; i++)
492static std::string buildMmaSparseAsmString(
493 const std::array<int64_t, 3> &
shape,
unsigned matASize,
unsigned matBSize,
494 unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
495 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
496 std::optional<NVVM::MMAIntOverflow> overflow,
unsigned metaDataSelector) {
497 auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
498 return NVVM::stringifyMMATypes(ptxType);
502 llvm::raw_string_ostream ss(asmStr);
503 ss <<
"mma.sp.sync.aligned.m" <<
shape[0] <<
"n" <<
shape[1] <<
"k"
504 <<
shape[2] <<
".row.col.";
507 ss << NVVM::stringifyMMAIntOverflow(*overflow) <<
".";
509 ss << ptxTypeStr(ptxTypeD) <<
"." << ptxTypeStr(ptxTypeA) <<
"."
510 << ptxTypeStr(ptxTypeB) <<
"." << ptxTypeStr(ptxTypeC) <<
" ";
511 unsigned asmArgIdx = 0;
515 for (
const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
517 for (
unsigned i = 0; i < arrSize; i++)
518 ss <<
"$" << asmArgIdx++ << (i < arrSize - 1 ?
"," :
"");
521 ss <<
"$" << asmArgIdx++ <<
",";
522 assert(metaDataSelector <= 1);
523 ss <<
"0x" << metaDataSelector <<
";";
529static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
531 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
532 std::optional<NVVM::MMAIntOverflow> overflow,
ArrayRef<Value> unpackedAData,
534 int64_t metadataSelector,
const std::array<int64_t, 3> &
shape,
535 Type intrinsicResultType) {
536 auto asmDialectAttr =
537 LLVM::AsmDialectAttr::get(
b.getContext(), LLVM::AsmDialect::AD_ATT);
539 const unsigned matASize = unpackedAData.size();
540 const unsigned matBSize = unpackedB.size();
541 const unsigned matCSize = unpackedC.size();
543 std::string asmStr = buildMmaSparseAsmString(
544 shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
545 ptxTypeD, overflow, metadataSelector);
546 std::string constraintStr =
547 buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
550 asmVals.reserve(matASize + matBSize + matCSize + 1);
552 llvm::append_range(asmVals, args);
553 asmVals.push_back(indexData);
555 return LLVM::InlineAsmOp::create(
b,
562 LLVM::TailCallKind::None,
568struct NVGPUMmaSparseSyncLowering
570 using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
573 matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
574 ConversionPatternRewriter &rewriter)
const override {
575 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
578 VectorType aType = op.getMatrixA().getType();
579 VectorType bType = op.getMatrixB().getType();
580 VectorType cType = op.getMatrixC().getType();
582 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
584 return op->emitOpError(
"failed to deduce operand PTX types");
585 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
587 return op->emitOpError(
"failed to deduce operand PTX types");
588 std::optional<NVVM::MMATypes> ptxTypeC =
589 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
592 return op->emitError(
593 "could not infer the PTX type for the accumulator/result");
596 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
597 if (aType.getElementType().isF32() && !tf32Enabled)
601 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
602 if (isa<IntegerType>(aType.getElementType()))
603 overflow = NVVM::MMAIntOverflow::satfinite;
605 SmallVector<Value> matA =
607 SmallVector<Value> matB =
609 SmallVector<Value> matC =
612 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
614 typeConverter->convertType(op->getResultTypes()[0]));
617 Value sparseMetadata = adaptor.getSparseMetadata();
618 if (sparseMetadata.
getType() != VectorType::get(2, rewriter.getI16Type()))
619 return op->emitOpError() <<
"Expected metadata type to be LLVM "
620 "VectorType of 2 i16 elements";
622 LLVM::BitcastOp::create(
b, rewriter.getI32Type(), sparseMetadata);
624 FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
625 b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
626 matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
628 if (
failed(intrinsicResult))
631 assert((*intrinsicResult).getNumResults() == 1 &&
632 "expected inline asm op returns a single LLVM struct type");
635 (*intrinsicResult)->getResult(0), rewriter));
640struct NVGPUAsyncCopyLowering
642 using ConvertOpToLLVMPattern<
643 nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
646 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
647 ConversionPatternRewriter &rewriter)
const override {
648 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
649 Location loc = op.getLoc();
650 auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
653 adaptor.getDst(), adaptor.getDstIndices());
654 FailureOr<unsigned> dstAddressSpace =
655 getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
656 if (
failed(dstAddressSpace))
657 return rewriter.notifyMatchFailure(
658 loc,
"destination memref address space not convertible to integer");
660 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
661 FailureOr<unsigned> srcAddressSpace =
662 getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
663 if (
failed(srcAddressSpace))
664 return rewriter.notifyMatchFailure(
665 loc,
"source memref address space not convertible to integer");
669 adaptor.getSrcIndices());
671 auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
672 op->getContext(),
static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
673 scrPtr = LLVM::AddrSpaceCastOp::create(
b, srcPointerGlobalType, scrPtr);
674 int64_t dstElements = adaptor.getDstElements().getZExtValue();
675 int64_t sizeInBytes =
676 (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
681 Value srcBytes = adaptor.getSrcElements();
688 LLVM::ConstantOp::create(
b,
b.getI32Type(),
b.getI32IntegerAttr(3));
689 Value bitwidth = LLVM::ConstantOp::create(
691 b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
692 Value srcElementsI32 = LLVM::TruncOp::create(
b,
b.getI32Type(), srcBytes);
693 srcBytes = LLVM::LShrOp::create(
694 b, LLVM::MulOp::create(
b, bitwidth, srcElementsI32), c3I32);
698 NVVM::LoadCacheModifierKind cacheModifier =
699 (op.getBypassL1().value_or(
false) && sizeInBytes == 16)
700 ? NVVM::LoadCacheModifierKind::CG
701 : NVVM::LoadCacheModifierKind::CA;
703 NVVM::CpAsyncOp::create(
704 b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
705 NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
710 LLVM::ConstantOp::create(
b, IntegerType::get(op.getContext(), 32),
711 rewriter.getI32IntegerAttr(0));
712 rewriter.replaceOp(op, zero);
717struct NVGPUAsyncCreateGroupLowering
719 using ConvertOpToLLVMPattern<
720 nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
723 matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
724 ConversionPatternRewriter &rewriter)
const override {
725 NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
727 Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
728 IntegerType::get(op.getContext(), 32),
729 rewriter.getI32IntegerAttr(0));
730 rewriter.replaceOp(op, zero);
735struct NVGPUAsyncWaitLowering
737 using ConvertOpToLLVMPattern<
738 nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
741 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
742 ConversionPatternRewriter &rewriter)
const override {
744 int32_t numGroups = adaptor.getNumGroups().value_or(0);
745 NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
746 rewriter.eraseOp(op);
752struct NVGPUMBarrierCreateLowering
754 using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
756 template <
typename moduleT>
757 memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
758 Operation *funcOp, moduleT moduleOp,
759 MemRefType barrierType)
const {
760 SymbolTable symbolTable(moduleOp);
761 OpBuilder::InsertionGuard guard(rewriter);
762 rewriter.setInsertionPoint(&moduleOp.front());
763 auto global = memref::GlobalOp::create(
764 rewriter, funcOp->
getLoc(),
"__mbarrier",
765 rewriter.getStringAttr(
"private"),
769 rewriter.getI64IntegerAttr(8));
770 symbolTable.insert(global);
775 matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
776 ConversionPatternRewriter &rewriter)
const override {
779 rewriter.getContext(), op.getBarriers().getType());
781 memref::GlobalOp global;
783 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
785 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
787 rewriter.setInsertionPoint(op);
788 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
795template <
typename SourceOp>
798 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
800 Value getMbarrierPtr(ImplicitLocOpBuilder &
b,
801 nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
803 ConversionPatternRewriter &rewriter)
const {
804 MemRefType mbarrierMemrefType =
807 rewriter,
b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
811struct NVGPUMBarrierGetLowering
812 :
public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
813 using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
816 matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
817 ConversionPatternRewriter &rewriter)
const override {
818 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
819 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
820 rewriter.setInsertionPoint(op);
821 Value barrier = getMbarrierPtr(
b, mbarrierType, adaptor.getBarriers(),
822 adaptor.getMbarId(), rewriter);
823 Type resType = op.getMbarrierPointer().getType();
824 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
830struct NVGPUMBarrierInitLowering
831 :
public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
832 using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
835 matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
836 ConversionPatternRewriter &rewriter)
const override {
837 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
838 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
839 rewriter.setInsertionPoint(op);
840 Value barrier = getMbarrierPtr(
b, mbarrierType, adaptor.getBarriers(),
841 adaptor.getMbarId(), rewriter);
843 rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
844 adaptor.getPredicate());
850struct NVGPUMBarrierArriveLowering
851 :
public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
852 using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
854 matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
855 ConversionPatternRewriter &rewriter)
const override {
856 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
858 getMbarrierPtr(
b, op.getBarriers().getType(), adaptor.getBarriers(),
859 adaptor.getMbarId(), rewriter);
860 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, barrier);
867struct NVGPUMBarrierArriveNoCompleteLowering
868 :
public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
869 using MBarrierBasePattern<
870 nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
872 matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
873 ConversionPatternRewriter &rewriter)
const override {
874 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
876 getMbarrierPtr(
b, op.getBarriers().getType(), adaptor.getBarriers(),
877 adaptor.getMbarId(), rewriter);
878 Type tokenType = getTypeConverter()->convertType(
879 nvgpu::MBarrierTokenType::get(op->getContext()));
881 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
882 op, tokenType, barrier, count);
888struct NVGPUMBarrierTestWaitLowering
889 :
public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
890 using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
892 matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
893 ConversionPatternRewriter &rewriter)
const override {
894 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
896 getMbarrierPtr(
b, op.getBarriers().getType(), adaptor.getBarriers(),
897 adaptor.getMbarId(), rewriter);
898 Type retType = rewriter.getI1Type();
899 rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier,
// NOTE(review): this region is garbled by extraction — original lines 909,
// 913, 920, and 923-926 are missing (including one argument in the middle of
// the NVVM::MBarrierArriveExpectTxOp::create call) and statements are wrapped
// mid-token. Restore the body from the upstream NVGPUToNVVM conversion file
// before relying on it; only comments are added here.
905struct NVGPUMBarrierArriveExpectTxLowering
906 :
public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
907 using MBarrierBasePattern<
908 nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
// Lowers nvgpu.mbarrier.arrive.expect_tx: takes the barrier address within
// the group, truncates the transaction count to i32 (truncToI32), and emits
// the NVVM arrive.expect_tx op with CTA scope and the adaptor's predicate,
// then erases the original op (it has no results to replace).
910 matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
911 ConversionPatternRewriter &rewriter)
const override {
912 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
914 getMbarrierPtr(
b, op.getBarriers().getType(), adaptor.getBarriers(),
915 adaptor.getMbarId(), rewriter);
916 Value txcount =
truncToI32(
b, adaptor.getTxcount());
917 NVVM::MBarrierArriveExpectTxOp::create(
918 rewriter, op->getLoc(), barrier, txcount,
919 NVVM::MemScopeKind::CTA,
// NOTE(review): original line 920 (an additional create(...) argument) is
// missing here — TODO confirm against upstream.
921 adaptor.getPredicate());
922 rewriter.eraseOp(op);
// NOTE(review): this region is garbled by extraction — original lines 931,
// 935, 938-939, and 942-946 are missing (including the declaration that the
// visible `barrier` use depends on and the trailing arguments of the
// replaceOpWithNewOp call). Restore from upstream before relying on it; only
// comments are added here.
927struct NVGPUMBarrierTryWaitParityLowering
928 :
public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
929 using MBarrierBasePattern<
930 nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
// Lowers nvgpu.mbarrier.try_wait.parity: resolves the barrier address in the
// group, zero-extends the phase-parity bit to i32, and rewrites to the NVVM
// try_wait.parity op (remaining operands lost to extraction — see note).
932 matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
933 ConversionPatternRewriter &rewriter)
const override {
934 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
936 getMbarrierPtr(
b, op.getBarriers().getType(), adaptor.getBarriers(),
937 adaptor.getMbarId(), rewriter);
940 LLVM::ZExtOp::create(
b,
b.getI32Type(), adaptor.getPhaseParity());
941 rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
947struct NVGPUTmaAsyncLoadOpLowering
948 :
public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
949 using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
951 matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
952 ConversionPatternRewriter &rewriter)
const override {
953 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
954 auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
956 adaptor.getDst(), {});
960 auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
962 static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
963 dest = LLVM::AddrSpaceCastOp::create(
b, ptrSharedClusterType, dest);
966 getMbarrierPtr(
b, op.getBarriers().getType(), adaptor.getBarriers(),
967 adaptor.getMbarId(), rewriter);
969 SmallVector<Value> coords = adaptor.getCoordinates();
970 for (
auto [index, value] : llvm::enumerate(coords)) {
975 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
976 op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
977 ValueRange{}, adaptor.getMulticastMask(), Value{},
978 NVVM::TMALoadMode::TILE,
981 adaptor.getPredicate());
986struct NVGPUTmaAsyncStoreOpLowering
987 :
public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
988 using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
990 matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
991 ConversionPatternRewriter &rewriter)
const override {
992 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
993 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
995 adaptor.getSrc(), {});
996 SmallVector<Value> coords = adaptor.getCoordinates();
997 for (
auto [index, value] : llvm::enumerate(coords)) {
1002 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1003 op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
1004 NVVM::TMAStoreMode::TILE,
1005 adaptor.getPredicate());
1010struct NVGPUGenerateWarpgroupDescriptorLowering
1012 using ConvertOpToLLVMPattern<
1013 nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1016 matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1017 ConversionPatternRewriter &rewriter)
const override {
1019 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
1021 nvgpu::TensorMapSwizzleKind swizzleKind =
1022 op.getTensorMap().getType().getSwizzle();
1025 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1026 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1027 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1030 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1031 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1032 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1035 auto ti64 =
b.getIntegerType(64);
1036 auto makeConst = [&](uint64_t index) -> Value {
1037 return LLVM::ConstantOp::create(
b, ti64,
b.getI64IntegerAttr(index));
1039 auto shiftLeft = [&](Value value,
unsigned shift) -> Value {
1040 return LLVM::ShlOp::create(
b, ti64, value, makeConst(shift));
1042 auto shiftRight = [&](Value value,
unsigned shift) -> Value {
1043 return LLVM::LShrOp::create(
b, ti64, value, makeConst(shift));
1045 auto insertBit = [&](Value desc, Value val,
int startBit) {
1046 return LLVM::OrOp::create(
b, ti64, desc, shiftLeft(val, startBit));
1049 int64_t sizeN = op.getTensorMap().
getType().getTensor().getDimSize(0);
1050 uint64_t strideDimVal = (layout << 3) >>
exclude4LSB;
1051 uint64_t leadDimVal = (sizeN * layout) >>
exclude4LSB;
1052 uint64_t offsetVal = 0;
1054 Value strideDim = makeConst(strideDimVal);
1055 Value leadDim = makeConst(leadDimVal);
1058 rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1059 adaptor.getTensor(), {});
1060 Value basePtr = LLVM::PtrToIntOp::create(
b, ti64, baseAddr);
1062 Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1064 int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1065 startLeadBit = 16, startBaseAddrBit = 0;
1066 Value dsc = makeConst(0);
1068 dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1070 dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1072 dsc = insertBit(dsc, strideDim, startStrideBit);
1074 dsc = insertBit(dsc, leadDim, startLeadBit);
1076 dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1078 LDBG() <<
"Generating warpgroup.descriptor: " <<
"leading_off:"
1079 << leadDimVal <<
"\t" <<
"stride_off :" << strideDimVal <<
"\t"
1080 <<
"base_offset:" << offsetVal <<
"\t" <<
"layout_type:" << swizzle
1081 <<
" (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1082 <<
")\n start_addr : " << baseAddr;
1084 rewriter.replaceOp(op, dsc);
1090 return LLVM::ConstantOp::create(
b,
b.getIntegerType(64),
1091 b.getI32IntegerAttr(
index));
1098 enum CUtensorMapDataTypeEnum {
1099 CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1100 CU_TENSOR_MAP_DATA_TYPE_UINT16,
1101 CU_TENSOR_MAP_DATA_TYPE_UINT32,
1102 CU_TENSOR_MAP_DATA_TYPE_INT32,
1103 CU_TENSOR_MAP_DATA_TYPE_UINT64,
1104 CU_TENSOR_MAP_DATA_TYPE_INT64,
1105 CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1106 CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1107 CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1108 CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1109 CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1110 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1111 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1115 return makeI64Const(
b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1117 return makeI64Const(
b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1119 return makeI64Const(
b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1121 return makeI64Const(
b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1123 return makeI64Const(
b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1125 return makeI64Const(
b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1127 return makeI64Const(
b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1129 return makeI64Const(
b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1131 return makeI64Const(
b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1133 return makeI64Const(
b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1135 llvm_unreachable(
"Not supported data type");
1138struct NVGPUTmaCreateDescriptorOpLowering
1140 using ConvertOpToLLVMPattern<
1141 nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1143 matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1144 ConversionPatternRewriter &rewriter)
const override {
1145 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
1146 auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1147 Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1149 Value tensorElementType =
1150 elementTypeAsLLVMConstant(
b, op.getTensor().getType().getElementType());
1151 auto promotedOperands = getTypeConverter()->promoteOperands(
1152 b.getLoc(), op->getOperands(), adaptor.getOperands(),
b);
1154 Value boxArrayPtr = LLVM::AllocaOp::create(
1155 b, llvmPointerType, llvmInt64Type, makeI64Const(
b, 5));
1156 for (
auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1157 Value gep = LLVM::GEPOp::create(
b, llvmPointerType, llvmPointerType,
1158 boxArrayPtr, makeI64Const(
b, index));
1159 LLVM::StoreOp::create(
b, value, gep);
1162 nvgpu::TensorMapDescriptorType desc = op.getTensorMap().
getType();
1164 SmallVector<Value> arguments;
1165 arguments.push_back(promotedOperands[0]);
1166 arguments.push_back(promotedOperands[1]);
1167 arguments.push_back(tensorElementType);
1168 arguments.push_back(
1169 makeI64Const(
b, (
int)desc.getInterleave()));
1170 arguments.push_back(makeI64Const(
b, (
int)desc.getSwizzle()));
1171 arguments.push_back(makeI64Const(
b, (
int)desc.getL2promo()));
1172 arguments.push_back(makeI64Const(
b, (
int)desc.getOob()));
1173 arguments.push_back(boxArrayPtr);
1176 SmallVector<Type> argTypes = {
1186 FunctionCallBuilder hostRegisterCallBuilder = {
1187 "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1189 hostRegisterCallBuilder.
create(
b.getLoc(),
b, arguments).getResult();
1191 rewriter.replaceOp(op, tensorMap);
// Lowers nvgpu.warpgroup.mma to a sequence of NVVM wgmma.mma_async ops,
// bracketed by wgmma fence / group-sync / wait ops (see generateWarpgroupMma).
// NOTE(review): this text is a lossy doxygen-style extraction; many original
// source lines (including the branch bodies of findWgmmaShape and most
// comments) are missing, and statements are split across lines with doxygen
// line numbers fused in. Recover the canonical NVGPUToNVVM.cpp before
// editing any code here.
1196struct NVGPUWarpgroupMmaOpLowering
 1198 using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 // Helper class that builds the full iteration space of wgmma.mma_async ops
 // for one warpgroup-level GEMM.
 1220 class WarpgroupGemm {
 1221 nvgpu::WarpgroupMmaOp op;
 1222 ImplicitLocOpBuilder b;
 // Full GEMM problem size, read from the A/B descriptor tensor types in the
 // constructor below.
 1226 int64_t totalM, totalN, totalK;
 // Shape of a single wgmma.mma_async instruction, chosen by findWgmmaShape.
 1229 int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
 // Number of wgmma instructions issued per dimension; computed in the
 // constructor as totalX / wgmmaX.
 1232 int iterationM = 0, iterationN = 0, iterationK = 0;
 // Selects the wgmma instruction shape from the input element type.
 // The branch bodies (the wgmmaM/N/K assignments) were lost in extraction;
 // the visible structure only shows the type dispatch: tf32, f16/bf16,
 // f8 (e4m3/e5m2), and i1, with llvm_unreachable for anything else.
 1237 void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
 1240 if (inputElemType.
isTF32()) {
 1242 }
else if (inputElemType.
isF16() || inputElemType.
isBF16()) {
 1244 }
else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
 1247 }
else if (inputElemType.
isInteger(1)) {
 1250 llvm_unreachable(
"msg: not supported K shape");
 1252 LDBG() <<
"Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
 1253 <<
", n = " << wgmmaN <<
", k = " << wgmmaK <<
"]";
 // Maps an MLIR element type to the corresponding NVVM::WGMMATypes enum and
 // wraps it in an attribute. `useF32` selects f32 over tf32 for the
 // accumulator/result type. Several guard conditions around the early
 // returns (e.g. for bf16/int cases) were dropped by the extraction.
 1257 NVVM::WGMMATypesAttr generateWgmmaType(Type type,
 1258 bool useF32 =
false)
const {
 1259 auto getWgmmaType = [=](Type elemType) {
 1261 return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
 1262 if (elemType.
isF16())
 1263 return NVVM::WGMMATypes::f16;
 1265 return NVVM::WGMMATypes::bf16;
 1266 if (isa<Float8E4M3FNType>(elemType))
 1267 return NVVM::WGMMATypes::e4m3;
 1268 if (isa<Float8E5M2Type>(elemType))
 1269 return NVVM::WGMMATypes::e5m2;
 1271 return NVVM::WGMMATypes::b1;
 1273 return NVVM::WGMMATypes::s8;
 1275 return NVVM::WGMMATypes::u8;
 1277 return NVVM::WGMMATypes::s32;
 1278 llvm_unreachable(
"unsupported type");
 1280 return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
 // Returns col layout when `transpose` is present and true, row otherwise.
 1285 generateWgmmaLayout(std::optional<bool> transpose)
const {
 1286 if (transpose.value_or(
false))
 1287 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
 1288 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
 // Shape attribute for one wgmma instruction (wgmmaM x wgmmaN x wgmmaK).
 1292 NVVM::MMAShapeAttr generateWgmmaShape()
const {
 1293 return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
 // Scale-out/in are fixed to `one` for this lowering.
 1297 NVVM::WGMMAScaleOutAttr generateScaleOut()
const {
 1298 return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
 1299 NVVM::WGMMAScaleOut::one);
 1302 NVVM::WGMMAScaleInAttr generateScaleIn()
const {
 1303 return NVVM::WGMMAScaleInAttr::get(op->getContext(),
 1304 NVVM::WGMMAScaleIn::one);
 // Small arithmetic helper: emits llvm.add with the lhs's type.
 1308 Value makeAdd(Value
lhs, Value
rhs) {
 1309 return LLVM::AddOp::create(b,
lhs.getType(),
lhs,
rhs);
 // Advances the matrix-A descriptor to the (i, k) tile by adding a byte
 // offset computed from the wgmma K step and the A tile shape. `byte` is
 // presumably the element size in bytes shifted per the descriptor
 // encoding — its defining line was lost in extraction; confirm in the
 // canonical source.
 1330 Value iterateDescriptorA(Value desc,
int i,
int j,
int k) {
 1331 MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
 1332 Type elemA = matrixTypeA.getElementType();
 1334 int tileShapeA = matrixTypeA.getDimSize(1);
 1335 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) *
byte;
 1337 LDBG() <<
"\t\t[m: " << i <<
" n: " << j <<
" k: " << k
 1338 <<
"] [wgmma descriptors] Descriptor A + " << incrementVal
 1342 return makeAdd(desc, makeI64Const(b, incrementVal));
 // Advances the matrix-B descriptor to the k-th tile (offset depends only
 // on k, not on i/j, per the visible expression).
 1356 Value iterateDescriptorB(Value desc,
int i,
int j,
int k) {
 1357 MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
 1358 Type elemB = matrixTypeB.getElementType();
 1360 int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
 1362 LDBG() <<
"Descriptor B + " << incrementVal;
 1365 return makeAdd(desc, makeI64Const(b, incrementVal));
 // Emits one wgmma.mma_async accumulating into `matrixC` for tile (i, j, k).
 1370 Value generateWgmma(
int i,
int j,
int k, Value matrixC) {
 1371 LDBG() <<
"\t wgmma." <<
"m" << wgmmaM <<
"n" << wgmmaN <<
"k" << wgmmaK
 1372 <<
"(A[" << (iterationM * wgmmaM) <<
":"
 1373 << (iterationM * wgmmaM) + wgmmaM <<
"][" << (iterationK * wgmmaK)
 1374 <<
":" << (iterationK * wgmmaK + wgmmaK) <<
"] * " <<
" B["
 1375 << (iterationK * wgmmaK) <<
":" << (iterationK * wgmmaK + wgmmaK)
 1376 <<
"][" << 0 <<
":" << wgmmaN <<
"])";
 1378 Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
 1379 Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
 1381 Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
 1382 NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
 1384 Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
 1385 NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
 1387 Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
 // useF32 = true for the result type (f32 instead of tf32).
 1388 NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD,
true);
 1390 NVVM::MMAShapeAttr shape = generateWgmmaShape();
 1391 NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
 1392 NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
 1393 NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
 // Note: B's layout is derived from the *negated* transposeB flag.
 1394 NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
 1396 auto overflow = NVVM::MMAIntOverflowAttr::get(
 1397 op->getContext(), NVVM::MMAIntOverflow::wrapped);
 1399 return NVVM::WgmmaMmaAsyncOp::create(
 1400 b, matrixC.
getType(), matrixC, descriptorA, descriptorB, shape,
 1401 itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
 // Emits the whole i/j/k loop nest of wgmma ops, threading the accumulator
 // through each k step, then packs the per-i results back into a struct.
 1406 Value generateWgmmaGroup() {
 1408 LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
 1411 SmallVector<Value> wgmmaResults;
 1412 for (
int i = 0; i < iterationM; ++i) {
 1414 LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
 1415 for (
int j = 0; j < iterationN; ++j)
 1416 for (
int k = 0; k < iterationK; ++k)
 1417 matrixC = generateWgmma(i, j, k, matrixC);
 1418 wgmmaResults.push_back(matrixC);
 1420 for (
auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
 1421 wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.
getType(),
 1422 wgmmaResult, matrix, idx);
 // Constructor: derives the GEMM problem size from the descriptor tensor
 // types, picks the wgmma shape, and computes the iteration counts.
 1428 WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
 1430 : op(op), b(b), adaptor(adaptor) {
 1432 totalM = op.getDescriptorA().
getType().getTensor().getDimSize(0);
 1433 totalN = op.getDescriptorB().
getType().getTensor().getDimSize(1);
 1434 totalK = op.getDescriptorA().
getType().getTensor().getDimSize(1);
 1435 LDBG() <<
"===--- GEMM D[" << totalM <<
"][" << totalN <<
"] += A["
 1436 << totalM <<
"][" << totalK <<
"] * B[" << totalK <<
"][" << totalN
 1442 op.getDescriptorA().getType().getTensor().getElementType());
 1445 iterationM = totalM / wgmmaM;
 1446 iterationN = totalN / wgmmaN;
 1447 iterationK = totalK / wgmmaK;
 // Full lowering sequence: fence-aligned, wgmma group, group sync, then
 // wait for op.getWaitGroup() outstanding groups.
 1455 Value generateWarpgroupMma() {
 1456 NVVM::WgmmaFenceAlignedOp::create(b);
 1457 Value wgmmaResult = generateWgmmaGroup();
 1458 NVVM::WgmmaGroupSyncAlignedOp::create(b);
 1459 NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
 // Pattern entry point: delegates to WarpgroupGemm and replaces the op with
 // the resulting accumulator struct.
 1464 matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
 1465 ConversionPatternRewriter &rewriter)
const override {
 1466 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
 1469 WarpgroupGemm warpgroupGemm(op,
b, adaptor);
 1472 Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
 1475 rewriter.replaceOp(op, wgmmaResult);
// Lowers nvgpu.warpgroup.mma.store: scatters the fragmented wgmma accumulator
// (a struct of structs of registers) into a 2-D memref, using the per-thread
// layout implied by the thread/lane arithmetic below.
// NOTE(review): lossy doxygen-style extraction — some lines (e.g. lambda
// parameter lists, closing braces, the `offset` declaration) are missing and
// statements are split across lines. Recover the canonical source before
// editing code.
1480struct NVGPUWarpgroupMmaStoreOpLowering
 1482 using ConvertOpToLLVMPattern<
 1483 nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
 // Stores one inner fragment struct `matrixD` into the destination memref,
 // starting at row offset `offset` (parameter list partially lost in
 // extraction).
 1521 void storeFragmentedMatrix(ImplicitLocOpBuilder &
b, Value matrixD,
 1524 Type i32 =
b.getI32Type();
 // i32 constant builder.
 1526 auto makeConst = [&](int32_t index) -> Value {
 1527 return LLVM::ConstantOp::create(
b, i32,
b.getI32IntegerAttr(index));
 1529 Value c1 = makeConst(1);
 1530 Value c2 = makeConst(2);
 1531 Value c4 = makeConst(4);
 1532 Value c8 = makeConst(8);
 1533 Value c16 = makeConst(16);
 // Arithmetic helpers over i32 values.
 1536 auto makeMul = [&](Value
lhs, Value
rhs) -> Value {
 1537 return LLVM::MulOp::create(
b,
lhs.getType(),
lhs,
rhs);
 1539 auto makeAdd = [&](Value
lhs, Value
rhs) -> Value {
 1540 return LLVM::AddOp::create(
b,
lhs.getType(),
lhs,
rhs);
 // Extracts the register pair (i, i+1) from the result struct and stores
 // them at rows/cols (x, y) and (x, y+1) of `memref`.
 1543 auto makeExtractAndStore = [&](
int i, Value wgmmaResult, Value x, Value y,
 1545 Type it =
b.getIndexType();
 1546 Value idx = arith::IndexCastOp::create(
b, it, x);
 1547 Value idy0 = arith::IndexCastOp::create(
b, it, y);
 1548 Value idy1 = arith::IndexCastOp::create(
b, it, makeAdd(y, c1));
 1549 Value d0 = LLVM::ExtractValueOp::create(
b, wgmmaResult, i);
 1550 Value d1 = LLVM::ExtractValueOp::create(
b, wgmmaResult, i + 1);
 1551 memref::StoreOp::create(
b, d0, memref,
ValueRange{idx, idy0});
 1552 memref::StoreOp::create(
b, d1, memref,
ValueRange{idx, idy1});
 // Decompose the thread id into warp / lane / lane-quad coordinates.
 // `warpSize` is defined outside the visible span — presumably 32; confirm.
 1555 Value tidx = NVVM::ThreadIdXOp::create(
b, i32);
 1556 Value laneId = LLVM::URemOp::create(
b, i32, tidx, warpSize);
 1557 Value warpId = LLVM::UDivOp::create(
b, i32, tidx, warpSize);
 1558 Value lane4Id = LLVM::UDivOp::create(
b, i32, laneId, c4);
 1559 Value lane4modId = LLVM::URemOp::create(
b, i32, laneId, c4);
 // Base (row, col) for this thread: tj = 2*(lane%4), ti = lane/4 + 16*warp,
 // then shifted down by the caller-provided row `offset`.
 1561 Value tj = makeMul(lane4modId, c2);
 1562 Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
 1564 ti = makeAdd(ti, makeConst(offset));
 1566 auto structType = cast<LLVM::LLVMStructType>(matrixD.
getType());
 // Each store step covers 2 adjacent registers across 2 stacked 8x8
 // matrices, hence body-size / 4 column iterations.
 1569 constexpr unsigned numAdjacentRegisters = 2;
 1571 constexpr unsigned numStackedMatrices = 2;
 1573 size_t storeCount = (structType.getBody().size() /
 1574 (numStackedMatrices * numAdjacentRegisters));
 1576 for (
size_t i = 0; i < numStackedMatrices; ++i) {
 1577 Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
 1578 for (
size_t j = 0; j < storeCount; ++j) {
 1579 Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
 1580 size_t structIndex = (i * numAdjacentRegisters) +
 1581 (j * (numStackedMatrices * numAdjacentRegisters));
 1582 makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
 // Pattern entry point: iterates the outer struct's members, storing each
 // inner fragment with an accumulated row offset, then erases the op.
 // (`offset` is declared on a line lost in extraction.)
 1588 matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
 1589 ConversionPatternRewriter &rewriter)
const override {
 1591 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
 1592 Value matriDValue = adaptor.getMatrixD();
 1593 auto stype = cast<LLVM::LLVMStructType>(matriDValue.
getType());
 1594 for (
auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
 1595 auto structType = cast<LLVM::LLVMStructType>(matrixD);
 1596 Value innerStructValue =
 1597 LLVM::ExtractValueOp::create(
b, matriDValue, idx);
 1598 storeFragmentedMatrix(
b, innerStructValue, op.getDstMemref(), offset);
 1599 offset += structType.getBody().size();
 1601 rewriter.eraseOp(op);
// Lowers nvgpu.warpgroup.mma.init.accumulator: builds the converted
// struct-of-structs accumulator type and fills every scalar slot with zero.
// NOTE(review): lossy doxygen-style extraction — statements are split across
// lines and some lines are missing (e.g. the tail of the elemType
// expression at 1616). Recover the canonical source before editing code.
1606struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
 1608 using ConvertOpToLLVMPattern<
 1609 nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
 1611 matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
 1612 ConversionPatternRewriter &rewriter)
const override {
 1613 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
 // Convert the accumulator type; it is an LLVM struct whose members are
 // themselves structs of scalar fragments.
 1614 LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
 1615 getTypeConverter()->convertType(op.getMatrixC().getType()));
 // Element type taken from the first inner struct (continuation of this
 // expression was lost in extraction).
 1616 Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
 1619 Value zero = LLVM::ConstantOp::create(
b, elemType,
b.getZeroAttr(elemType));
 1620 Value packStruct = LLVM::PoisonOp::create(
b, packStructType);
 1621 SmallVector<Value> innerStructs;
 // Zero-fill every scalar slot of each inner struct.
 1623 for (
auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
 1624 auto structType = cast<LLVM::LLVMStructType>(s);
 1625 Value structValue = LLVM::ExtractValueOp::create(
b, packStruct, idx);
 1626 for (
unsigned i = 0; i < structType.getBody().size(); ++i) {
 1627 structValue = LLVM::InsertValueOp::create(
b, structType, structValue,
 1628 zero, ArrayRef<int64_t>({i}));
 1630 innerStructs.push_back(structValue);
 // Re-assemble the outer struct from the zeroed inner structs.
 1633 for (
auto [idx, matrix] : llvm::enumerate(innerStructs)) {
 1634 packStruct = LLVM::InsertValueOp::create(
b, packStruct.
getType(),
 1635 packStruct, matrix, idx);
 1637 rewriter.replaceOp(op, packStruct);
// Lowers nvgpu.tma.fence.descriptor to nvvm.fence.proxy.acquire with SYS
// scope over the 128-byte tensor-map descriptor.
// NOTE(review): lossy doxygen-style extraction; the `memscope` variable
// declaration line is missing. Recover the canonical source before editing.
1642struct NVGPUTmaFenceOpLowering
 1644 using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
 1646 matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
 1647 ConversionPatternRewriter &rewriter)
const override {
 1648 MLIRContext *ctx = op.getContext();
 1649 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
 1650 auto i32Ty =
b.getI32Type();
 // A CUDA tensor map (TMA descriptor) is 128 bytes.
 1651 Value tensormapSize =
 1652 LLVM::ConstantOp::create(
b, i32Ty, rewriter.getI32IntegerAttr(128));
 // System-wide scope for the proxy-acquire fence.
 1655 NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);
 1657 rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
 1658 op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
// Lowers nvgpu.tma.prefetch.descriptor to nvvm.prefetch on the tensor-map
// descriptor pointer, forwarding the optional predicate.
1664struct NVGPUTmaPrefetchOpLowering
 1666 using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
 1668 matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
 1669 ConversionPatternRewriter &rewriter)
const override {
 // The two nullptrs are unset optional attributes of NVVM::PrefetchOp;
 // the trailing UnitAttr marks this as a tensormap prefetch.
 1670 rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
 1671 op,
nullptr,
nullptr,
 1672 adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
 1673 mlir::UnitAttr::get(op.getContext()));
// Lowers nvgpu.rcp (approximate reciprocal) element-wise to
// nvvm.rcp.approx.ftz.f32, handling 1-D vectors directly and higher-rank
// vectors via handleMultidimensionalVectors.
// NOTE(review): lossy doxygen-style extraction — the struct header line and
// several closing/connecting lines are missing. Recover the canonical source
// before editing code.
1679 using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
 1681 matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
 1682 ConversionPatternRewriter &rewriter)
const override {
 1683 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
 1684 auto i64Ty =
b.getI64Type();
 1685 auto f32Ty =
b.getF32Type();
 1686 VectorType inTy = op.getIn().getType();
 // Applies rcp.approx.ftz.f32 to every element of a 1-D vector by
 // extract / rcp / insert.
 1688 auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
 1689 Value ret1DVec = LLVM::PoisonOp::create(
b, llvm1DVectorTy);
 1690 int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
 1691 for (
int i = 0; i < numElems; i++) {
 1692 Value idx = LLVM::ConstantOp::create(
b, i64Ty,
b.getI64IntegerAttr(i));
 1693 Value elem = LLVM::ExtractElementOp::create(
b, inVec, idx);
 1694 Value dst = NVVM::RcpApproxFtzF32Op::create(
b, f32Ty, elem);
 1695 ret1DVec = LLVM::InsertElementOp::create(
b, ret1DVec, dst, idx);
 // Rank-1 input: convert in place and replace the op.
 1699 if (inTy.getRank() == 1) {
 1700 rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
 // Higher ranks: decompose into 1-D vectors (the enclosing
 // handleMultidimensionalVectors call was partially lost in extraction).
 1704 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
 1705 [&](Type llvm1DVectorTy,
ValueRange operands) -> Value {
 1706 OpAdaptor adaptor(operands);
 1707 return convert1DVec(llvm1DVectorTy, adaptor.getIn());
// Interior of populateNVGPUToNVVMConversionPatterns (function header lost in
// extraction): maps gpu.address_space values to NVVM memory spaces, then
// registers all NVGPU->NVVM lowering patterns.
// NOTE(review): the switch keyword, the Private-case return, and the
// patterns.add<...> opening were dropped by the extraction.
1721 typeConverter, [](gpu::AddressSpace space) ->
unsigned {
 1723 case gpu::AddressSpace::Global:
 1724 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
 1725 case gpu::AddressSpace::Workgroup:
 1726 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
 1727 case gpu::AddressSpace::Private:
 1729 case gpu::AddressSpace::Constant:
 1730 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Constant);
 1732 llvm_unreachable(
"unknown address space enum value");
 // Pattern registration list (one entry per lowering defined in this file).
 1739 NVGPUMBarrierCreateLowering,
 1740 NVGPUMBarrierInitLowering,
 1741 NVGPUMBarrierGetLowering,
 1742 NVGPUMBarrierArriveLowering,
 1743 NVGPUMBarrierArriveNoCompleteLowering,
 1744 NVGPUMBarrierTestWaitLowering,
 1745 NVGPUMBarrierTryWaitParityLowering,
 1746 NVGPUTmaAsyncLoadOpLowering,
 1747 NVGPUTmaAsyncStoreOpLowering,
 1748 NVGPUTmaCreateDescriptorOpLowering,
 1749 NVGPUTmaPrefetchOpLowering,
 1750 NVGPUTmaFenceOpLowering,
 1751 NVGPUMBarrierArriveExpectTxLowering,
 1752 NVGPUGenerateWarpgroupDescriptorLowering,
 1753 NVGPUWarpgroupMmaOpLowering,
 1754 NVGPUWarpgroupMmaStoreOpLowering,
 1755 NVGPUWarpgroupMmaInitAccumulatorOpLowering,
 1756 MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
 1757 NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
 1758 NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value)
GPU has 32 bit registers, this function truncates values when larger width is not needed.
static SmallVector< Value > unpackOperandVector(ImplicitLocOpBuilder &b, Value operand, NVVM::MMATypes operandPtxType)
The gpu.mma.sync converter below expects matrix fragment operands to be given as 2D vectors where the rows are 32b wide.
static Type inferIntrinsicResultType(Type vectorResultType)
Returns the type for the intrinsic given the vectorResultType of the gpu.mma.sync operation.
constexpr int exclude4LSB
Number of bits that needs to be excluded when building matrix descriptor for wgmma operations.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType)
Returns whether mbarrier object has shared memory address space.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, Type resultType, Value intrinsicResult, RewriterBase &rewriter)
Convert the SSA result of the NVVM intrinsic nvvm.mma.sync (which is always an LLVM struct) into a fragment that is compatible with the vector type of this operation.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
MLIRContext * getContext() const
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
MemRefType getMBarrierMemrefType(MLIRContext *context, MBarrierGroupType barrierType)
Return the memref type that can be used to represent an mbarrier object.
Attribute getMbarrierMemorySpace(MLIRContext *context, MBarrierGroupType barrierType)
Returns the memory space attribute of the mbarrier object.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef< Value > arguments) const