#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "nvgpu-to-nvvm"

#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"

/// Number of bits that needs to be excluded when building the matrix
/// descriptor for wgmma operations (addresses are 16-byte aligned, so the 4
/// least-significant bits are implicit).
constexpr int exclude4LSB = 4;
/// GPU has 32-bit registers; this function truncates values when a larger
/// bit width is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  if (type.getIntOrFloatBitWidth() <= 32)
    return value;
  return LLVM::TruncOp::create(b, b.getI32Type(), value);
}
/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = VectorType::get(2, i32Ty);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = VectorType::get(2, f64Ty);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = VectorType::get(2, f32Ty);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements() * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements() * 2, f32Ty));
  }
  if (a.getElementType() == VectorType::get(1, f32Ty)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f32Ty));
  }
  return vectorResultType;
}
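// Illustrative example (not from the original source): for a `gpu.mma.sync`
// whose accumulator converts to `!llvm.array<2 x vector<2xf16>>`, the
// inferred intrinsic result type is the literal struct
// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`; for
// `!llvm.array<2 x vector<2xi32>>` it is `!llvm.struct<(i32, i32, i32, i32)>`,
// since i32 rows come back from the intrinsic as individual scalars.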
/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
  Type i32x2Ty = VectorType::get(2, i32Ty);
  Type f64x2Ty = VectorType::get(2, f64Ty);
  Type f32x2Ty = VectorType::get(2, f32Ty);
  Type f32x1Ty = VectorType::get(1, f32Ty);

  auto makeConst = [&](int32_t index) -> Value {
    return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
                                    rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form that can be
    // directly bitcast and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars;
    // extract them from the struct and pack them pairwise into the 64-bit
    // wide rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {
      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
        Value x1 =
            LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
        Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
                                                i * 2 + 1);
        vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
                                            x1, makeConst(0));
        vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
                                            x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
                                           el.index());
    }
    return result;
  }

  return intrinsicResult;
}
/// The `gpu.mma.sync` converter below expects matrix fragment operands as 2D
/// vectors whose rows are 32b or 64b wide, while `nvvm.mma.sync` expects a
/// flat list of scalars of certain types. This function unpacks the vector
/// arguments and casts them to the types expected by `nvvm.mma.sync`.
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = b.getI32Type();
  Type f64Ty = b.getF64Type();
  Type f32Ty = b.getF32Type();
  Type i64Ty = b.getI64Type();
  Type i8x4Ty = VectorType::get(4, b.getI8Type());
  Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
  Type f32x1Ty = VectorType::get(1, f32Ty);
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = LLVM::ExtractValueOp::create(b, operand, i);

    // 4xi8, 8xi4, and (for tf32) 1xf32 vectors are provided to the intrinsic
    // as i32 scalars.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
      continue;
    }

    // For i32, f64, and f32 element types, also unpack the inner vector,
    // because the intrinsic expects individual scalars.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(LLVM::ExtractElementOp::create(
            b, toUse,
            LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}
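// Illustrative example (not from the original source): an operand of type
// `!llvm.array<2 x vector<2xf16>>` passes through as two `vector<2xf16>`
// values, whereas `!llvm.array<4 x vector<1xf32>>` with a tf32 PTX type is
// bitcast to four i32 scalars, matching the `nvvm.mma.sync` operand list.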
/// Returns whether the mbarrier object has shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = {};
  if (isMbarrierShared(barrierType)) {
    memorySpace =
        IntegerAttr::get(IntegerType::get(context, 64),
                         nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
  }
  return memorySpace;
}

/// Returns the memref type that can be used to represent an mbarrier object.
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
  MemRefLayoutAttrInterface layout;
  return MemRefType::get({barrierType.getNumBarriers()},
                         IntegerType::get(context, 64), layout, memorySpace);
}
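// Illustrative example (not from the original source): a barrier group type
// `!nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>,
// num_barriers = 4>` is represented as `memref<4xi64, 3>`, i.e. one 64-bit
// mbarrier word per barrier, placed in shared memory.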
struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // The result type of ldmatrix will always be a struct of 32-bit integer
    // registers if more than one 32-bit value is returned. Otherwise, the
    // result is a single i32.
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
                                           vectorResultType.getElementType());

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
    Value srcPtr =
        getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
                             adaptor.getSrcMemref(), adaptor.getIndices());
    auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
    Value ldMatrixResult = NVVM::LdMatrixOp::create(
        b, ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row,
        shape, NVVM::LdStMatrixEltType::B16);

    // The ldmatrix operation returns either a single i32 value or a struct
    // of i32 values. Unpack those values, cast them back to their actual
    // vector type (still 32 bits wide), and repack them into the result.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = LLVM::PoisonOp::create(b, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
      result = LLVM::InsertValueOp::create(b, result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
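// Illustrative lowering sketch (not verbatim from the tests): an op like
//   %r = nvgpu.ldmatrix %sm[%i, %j] {numTiles = 4 : i32, transpose = false}
//        : memref<128x128xf16, 3> -> vector<4x2xf16>
// becomes an `nvvm.ldmatrix` returning `!llvm.struct<(i32, i32, i32, i32)>`,
// followed by per-register bitcasts to `vector<2xf16>` and insertions into
// `!llvm.array<4 x vector<2xf16>>`.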
/// Returns the NVVM/PTX MMA type corresponding to the element type of `t`.
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}
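// Note that f32 deliberately maps to tf32 here: Tensor Core `mma.sync`
// consumes f32 operands only via the tf32 format, which is why the pattern
// below bails out on f32 inputs unless the `tf32Enabled` attribute is set.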
struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes choose
    // which intrinsic this op lowers to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 work only with TF32.
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Integer MMA uses saturating arithmetic on overflow.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = NVVM::MmaOp::create(
        b, intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/std::nullopt,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};
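// Illustrative lowering sketch (not verbatim from the tests): for
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>)
//        -> vector<2x2xf16>
// the pattern unpacks the fragments into scalars, emits `nvvm.mma.sync` with
// row/col multiplicand layouts, and repacks the returned struct into
// `!llvm.array<2 x vector<2xf16>>`.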
struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
  using Base::Base;

  void runOnOperation() override {
    LowerToLLVMOptions options(&getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext(), options);
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
          case gpu::AddressSpace::Private:
            return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
          }
          llvm_unreachable("unknown address space enum value");
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
        });
    // Device-side async tokens cannot be materialized in NVVM. Convert them
    // to a dummy i32 type so they can easily be dropped during conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);

      unsigned numMembers;
      if (elemType.isF32() || elemType.isInteger(32))
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
      else
        llvm_unreachable("unsupported type for warpgroup accumulator");

      SmallVector<Type> innerStructBody;
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

      SmallVector<Type> structBody;
      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
        structBody.push_back(innerStructType);

      auto convertedType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    });
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 64));
    });
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
          return converter.convertType(IntegerType::get(type.getContext(), 64));
        });
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
          nvgpu::getMBarrierMemrefType(type.getContext(), type));
    });
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
/// Returns the constraint string for the sparse MMA inline assembly: one
/// "=r" per result register, one "r" per input register, and a final "r"
/// for the sparsity metadata.
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
    ss << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    ss << "r,";
  // The final operand is for the sparsity metadata.
  ss << "r";
  return str;
}

/// Returns the string for the `mma.sp.sync` inline assembly.
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  // Operand groups appear as {D},{A},{B},{C}; D uses the same register count
  // as C.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
  return asmStr;
}
/// Builds an inline assembly operation corresponding to the specified sparse
/// MMA sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  SmallVector<Value> asmVals;
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return LLVM::InlineAsmOp::create(
      b,
      /*resultTypes=*/intrinsicResultType,
      /*operands=*/asmVals,
      /*asm_string=*/asmStr,
      /*constraints=*/constraintStr,
      /*has_side_effects=*/true,
      /*is_align_stack=*/false, LLVM::TailCallKind::None,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
}
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Tensor Cores (mma.sync) on F32 work only with TF32.
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    // Integer MMA uses saturating arithmetic on overflow.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};
struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Location loc = op.getLoc();
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dstPtr =
        getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
                             adaptor.getDst(), adaptor.getDstIndices());
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
                             adaptor.getSrcIndices());
    // The intrinsic requires a pointer in the global address space.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
    srcPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, srcPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
    // When the optional SrcElements argument is present, only SrcElements
    // elements are read from the source (global memory); the remaining
    // DstElements in the destination (shared memory) are filled with zeros.
    Value srcBytes = adaptor.getSrcElements();
    if (srcBytes) {
      // Convert the number of elements to a number of bytes:
      // bytes = (elements * bitwidth) >> 3, i.e. a division by 8.
      Value c3I32 =
          LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
      Value bitwidth = LLVM::ConstantOp::create(
          b, b.getI32Type(),
          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
      Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
      srcBytes = LLVM::LShrOp::create(
          b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
    }
    // Bypass the L1 cache only for 16-byte transfers, the only size
    // cp.async supports with the .cg (cache-global) modifier.
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    NVVM::CpAsyncOp::create(
        b, dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes),
        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
        srcBytes);

    // Drop the result token: replace it with a dummy i32 constant.
    Value zero =
        LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
                                 rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};
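// Illustrative lowering sketch (not verbatim from the tests):
//   %t = nvgpu.device_async_copy %src[%i], %dst[%j], 4
//        : memref<128xf32> to memref<128xf32, 3>
// becomes an `nvvm.cp.async.shared.global` of 16 bytes with the .ca (or,
// for bypassL1 on 16-byte transfers, .cg) cache modifier, and the async
// token is replaced by a dummy i32 constant.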
struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
    // Drop the result token: replace it with a dummy i32 constant.
    Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
                                          IntegerType::get(op.getContext(), 32),
                                          rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, wait until all groups have completed.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Creates the mbarrier object in shared memory as a memref.global.
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = memref::GlobalOp::create(
        rewriter, funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/ElementsAttr(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
                                                     global.getName());
    return success();
  }
};
/// Base class for lowering mbarrier operations to NVVM intrinsics.
template <typename SourceOp>
struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  /// Returns the base pointer of the mbarrier object at `mbarId`.
  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
                       Value mbarId,
                       ConversionPatternRewriter &rewriter) const {
    MemRefType mbarrierMemrefType =
        nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
    return ConvertToLLVMPattern::getStridedElementPtr(
        rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
  }
};
struct NVGPUMBarrierGetLowering
    : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Type resType = op.getMbarrierPointer().getType();
    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
    return success();
  }
};
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Value count = truncToI32(b, adaptor.getCount());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
                                                      adaptor.getPredicate());
    return success();
  }
};
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
    return success();
  }
};
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    Value count = truncToI32(b, adaptor.getCount());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
        op, tokenType, barrier, count);
    return success();
  }
};
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type retType = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier,
                                                          adaptor.getToken());
    return success();
  }
};
struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value txcount = truncToI32(b, adaptor.getTxcount());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
        op, barrier, txcount, NVVM::MemScopeKind::CTA,
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase =
        LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
                                                               phase, ticks);
    return success();
  }
};
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dest = getStridedElementPtr(rewriter, op->getLoc(), dstMemrefType,
                                      adaptor.getDst(), {});
    // The intrinsic expects the destination in the shared::cluster address
    // space.
    auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
        op->getContext(),
        static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
    dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);

    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);

    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        ValueRange{}, adaptor.getMulticastMask(), Value{},
        NVVM::TMALoadMode::TILE,
        /*isCTAOnly=*/false,
        /*group=*/nullptr,
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
                                      adaptor.getSrc(), {});
    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
        op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
        NVVM::TMAStoreMode::TILE,
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

    unsigned layout =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
                                                                    : 1;
    unsigned swizzle =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
                                                                    : 0;

    auto ti64 = b.getIntegerType(64);
    auto makeConst = [&](uint64_t index) -> Value {
      return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
    };
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

    Value baseAddr = getStridedElementPtr(
        rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {});
    Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
    // Use only 14 bits for the base address.
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    // Swizzle kind
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    // Base offset
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    // Stride dimension
    dsc = insertBit(dsc, strideDim, startStrideBit);
    // Leading dimension
    dsc = insertBit(dsc, leadDim, startLeadBit);
    // Base address
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LDBG() << "Generating warpgroup.descriptor: "
           << "leading_off:" << leadDimVal << "\t"
           << "stride_off :" << strideDimVal << "\t"
           << "base_offset:" << offsetVal << "\t"
           << "layout_type:" << swizzle << " ("
           << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
           << ")\n start_addr : " << baseAddr;

    rewriter.replaceOp(op, dsc);
    return success();
  }
};
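// Descriptor bit layout used above (derived from the start-bit constants;
// field widths are implied by the neighboring fields):
//   bits [62..63] swizzle kind
//   bits [49..61] base offset
//   bits [32..48] stride dimension (in 16-byte units; exclude4LSB shifts
//                 out the implicit 4 low bits)
//   bits [16..31] leading dimension (in 16-byte units)
//   bits [ 0..13] 14-bit shared-memory base address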
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
  return LLVM::ConstantOp::create(b, b.getIntegerType(64),
                                  b.getI64IntegerAttr(index));
}

/// Returns a Value holding the data-type enum expected by the CUDA driver.
static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
  // Enum from the CUDA driver API:
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
  enum CUtensorMapDataTypeEnum {
    CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
    CU_TENSOR_MAP_DATA_TYPE_UINT16,
    CU_TENSOR_MAP_DATA_TYPE_UINT32,
    CU_TENSOR_MAP_DATA_TYPE_INT32,
    CU_TENSOR_MAP_DATA_TYPE_UINT64,
    CU_TENSOR_MAP_DATA_TYPE_INT64,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
    CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
  };

  if (type.isUnsignedInteger(8))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
  if (type.isUnsignedInteger(16))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
  if (type.isUnsignedInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
  if (type.isUnsignedInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
  if (type.isSignlessInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
  if (type.isSignlessInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
  if (type.isF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
  if (type.isF32())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
  if (type.isF64())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
  if (type.isBF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);

  llvm_unreachable("Not supported data type");
}
struct NVGPUTmaCreateDescriptorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
    Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

    Value tensorElementType = elementTypeAsLLVMConstant(
        b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    // Stash the box dimensions in a stack-allocated array.
    Value boxArrayPtr = LLVM::AllocaOp::create(
        b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
                                      boxArrayPtr, makeI64Const(b, index));
      LLVM::StoreOp::create(b, value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    // Set up the arguments of the runtime call.
    SmallVector<Value> arguments;
    arguments.push_back(promotedOperands[0]);
    arguments.push_back(promotedOperands[1]);
    arguments.push_back(tensorElementType);
    arguments.push_back(
        makeI64Const(b, (int)desc.getInterleave()));
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle()));
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo()));
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));
    arguments.push_back(boxArrayPtr);

    // Argument types matching the `arguments` vector above.
    SmallVector<Type> argTypes = {llvmPointerType, llvmPointerType,
                                  llvmInt64Type,   llvmInt64Type,
                                  llvmInt64Type,   llvmInt64Type,
                                  llvmInt64Type,   llvmPointerType};
    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

    rewriter.replaceOp(op, tensorMap);
    return success();
  }
};
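// The descriptor is created on the host through the CUDA runtime wrapper
// `mgpuTensorMapEncodeTiledMemref` (part of MLIR's CUDA runner runtime
// library), which in turn calls cuTensorMapEncodeTiled from the CUDA driver
// API.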
struct NVGPUWarpgroupMmaOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

  /// Helper class that generates the required NVVM ops for a warpgroup-level
  /// matrix multiplication.
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    ImplicitLocOpBuilder b;
    OpAdaptor adaptor;

    // Entire shape of the given op:
    // D[totalM][totalN] += A[totalM][totalK] * B[totalK][totalN].
    int64_t totalM, totalN, totalK;

    // Shape of one wgmma instruction: wgmmaM x wgmmaN x wgmmaK.
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

    // Iteration counts for the GEMM:
    // (totalM/wgmmaM) x (totalN/wgmmaN) x (totalK/wgmmaK).
    int iterationM = 0, iterationN = 0, iterationK = 0;
    /// Computes the wgmma instruction shape from the shape of the GEMM and
    /// the element type of the input matrices.
    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      wgmmaM = 64;
      wgmmaN = sizeN;
      if (inputElemType.isTF32()) {
        wgmmaK = 8;
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
        wgmmaK = 16;
      } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
                 inputElemType.isInteger(8)) {
        wgmmaK = 32;
      } else if (inputElemType.isInteger(1)) {
        wgmmaK = 256;
      } else {
        llvm_unreachable("msg: not supported K shape");
      }
      LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
             << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
    }
    /// Generates the wgmma type attribute for the given element type; for the
    /// result type, `useF32` selects f32 accumulation instead of tf32 input.
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
        if (elemType.isF32() || elemType.isTF32())
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
        if (elemType.isBF16())
          return NVVM::WGMMATypes::bf16;
        if (isa<Float8E4M3FNType>(elemType))
          return NVVM::WGMMATypes::e4m3;
        if (isa<Float8E5M2Type>(elemType))
          return NVVM::WGMMATypes::e5m2;
        if (elemType.isInteger(1))
          return NVVM::WGMMATypes::b1;
        if (elemType.isSignedInteger(8))
          return NVVM::WGMMATypes::s8;
        if (elemType.isUnsignedInteger(8))
          return NVVM::WGMMATypes::u8;
        if (elemType.isInteger(32))
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");
      };
      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
    }

    /// Generates the layout attribute; transposed operands use column-major.
    NVVM::MMALayoutAttr
    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))
        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
    }

    /// Generates the wgmma shape attribute.
    NVVM::MMAShapeAttr generateWgmmaShape() const {
      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
    }

    /// Generates the scale-out attribute.
    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
                                          NVVM::WGMMAScaleOut::one);
    }

    /// Generates the scale-in attribute for the A and B operands.
    NVVM::WGMMAScaleInAttr generateScaleIn() const {
      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
                                         NVVM::WGMMAScaleIn::one);
    }

    /// Basic helper to produce an integer add.
    Value makeAdd(Value lhs, Value rhs) {
      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    }
    /// Moves the descriptor of matrix A to the current wgmma tile; the
    /// descriptor encodes the tile's base address in shared memory.
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();
      int byte = elemA.getIntOrFloatBitWidth() / 8;
      int tileShapeA = matrixTypeA.getDimSize(1);
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
             << "] [wgmma descriptors] Descriptor A + " << incrementVal;
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// Moves the descriptor of matrix B to the current wgmma tile.
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LDBG() << "Descriptor B + " << incrementVal;
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }
    /// Generates a single wgmma instruction for the given tile coordinates.
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LDBG() << "\t wgmma."
             << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
             << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
             << "][" << (iterationK * wgmmaK) << ":"
             << (iterationK * wgmmaK + wgmmaK) << "] * B["
             << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
             << "][" << 0 << ":" << wgmmaN << "])";

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, /*useF32=*/true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return NVVM::WgmmaMmaAsyncOp::create(
          b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
          itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }
    /// Generates multiple wgmma instructions to complete the given GEMM
    /// shape.
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());

      // Perform the GEMM.
      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        Value matrixC =
            LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
                                                  wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }
  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM shape.
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
             << totalM << "][" << totalK << "] * B[" << totalK << "]["
             << totalN << "] ---===";

      // Find the shape of a single wgmma instruction.
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iterations the GEMM needs.
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
    }

    /// Generates the warpgroup MMA: a fence, the wgmma group, a commit, and
    /// a wait for completion of the group.
    Value generateWarpgroupMma() {
      NVVM::WgmmaFenceAlignedOp::create(b);
      Value wgmmaResult = generateWgmmaGroup();
      NVVM::WgmmaGroupSyncAlignedOp::create(b);
      NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
      return wgmmaResult;
    }
  };
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build the helper.
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Generate the wgmma group.
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace the op with the result.
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};
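// Illustrative sequence (not verbatim): a 128x128x64 f16 GEMM with
// wgmma.m64n128k16 tiles results in
//   nvvm.wgmma.fence.aligned
//   8x nvvm.wgmma.mma_async        (2 M-iterations x 4 K-iterations)
//   nvvm.wgmma.commit.group.sync.aligned
//   nvvm.wgmma.wait.group.sync.aligned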
struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// Stores a fragmented matrix (the wgmma accumulator registers of one
  /// warpgroup) to the given memref, mapping each thread's registers to the
  /// matrix elements it owns.
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(kWarpSize);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<::mlir::MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = arith::IndexCastOp::create(b, it, x);
      Value idy0 = arith::IndexCastOp::create(b, it, y);
      Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
      Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
      Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
      memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
      memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = NVVM::ThreadIdXOp::create(b, i32);
    Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
    Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
    Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
    Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    if (offset)
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    // Number of 32-bit registers owned per matrix.
    constexpr unsigned numAdjacentRegisters = 2;
    // Number of 8x8 matrices stacked one below another per warp.
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }
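  // Thread-to-element mapping used above (each lane owns 2x2 adjacent
  // accumulator registers per pair of stacked 8x8 tiles):
  //   row ti = warpId * 16 + (laneId / 4)   (+8 for the stacked tile)
  //   col tj = (laneId % 4) * 2             (+8 per storeCount step)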
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matrixDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          LLVM::ExtractValueOp::create(b, matrixDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
    rewriter.eraseOp(op);
    return success();
  }
};
struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType =
        cast<LLVM::LLVMStructType>(packStructType.getBody().front())
            .getBody()
            .front();
    Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
    Value packStruct = LLVM::PoisonOp::create(b, packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the structs, set all their values to zero, and pack them again.
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = LLVM::InsertValueOp::create(b, structType, structValue,
                                                  zero, ArrayRef<int64_t>({i}));
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs into a single struct.
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
                                               packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};
struct NVGPUTmaFenceOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = op.getContext();
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i32Ty = b.getI32Type();
    // A CUtensorMap descriptor is 128 bytes.
    Value tensormapSize =
        LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));

    auto memscope =
        NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);

    rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
        op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
    return success();
  }
};
struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
        op, nullptr, nullptr,
        adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
        mlir::UnitAttr::get(op.getContext()));
    return success();
  }
};
struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i64Ty = b.getI64Type();
    auto f32Ty = b.getF32Type();
    VectorType inTy = op.getIn().getType();
    // Apply rcp.approx.ftz.f to each element of the vector.
    auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
      Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
      int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
      for (int i = 0; i < numElems; i++) {
        Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
        Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
        Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
        ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
      }
      return ret1DVec;
    };
    if (inTy.getRank() == 1) {
      rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
      return success();
    }
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
          OpAdaptor adaptor(operands);
          return convert1DVec(llvm1DVectorTy, adaptor.getIn());
        },
        rewriter);
  }
};
void mlir::populateNVGPUToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<
      NVGPUMBarrierCreateLowering,
      NVGPUMBarrierInitLowering,
      NVGPUMBarrierGetLowering,
      NVGPUMBarrierArriveLowering,
      NVGPUMBarrierArriveNoCompleteLowering,
      NVGPUMBarrierTestWaitLowering,
      NVGPUMBarrierTryWaitParityLowering,
      NVGPUTmaAsyncLoadOpLowering,
      NVGPUTmaAsyncStoreOpLowering,
      NVGPUTmaCreateDescriptorOpLowering,
      NVGPUTmaPrefetchOpLowering,
      NVGPUTmaFenceOpLowering,
      NVGPUMBarrierArriveExpectTxLowering,
      NVGPUGenerateWarpgroupDescriptorLowering,
      NVGPUWarpgroupMmaOpLowering,
      NVGPUWarpgroupMmaStoreOpLowering,
      NVGPUWarpgroupMmaInitAccumulatorOpLowering,
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
}
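// Usage sketch (pass flag assumed from GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS;
// check Passes.td for the exact spelling):
//   mlir-opt input.mlir --convert-nvgpu-to-nvvm
// or, programmatically, populate a RewritePatternSet via
// populateNVGPUToNVVMConversionPatterns(converter, patterns) and run a
// partial conversion targeting the NVVM/LLVM dialects, as the pass above
// does.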