#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "nvgpu-to-nvvm"

#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
/// GPU has 32 bit registers, this function truncates values when larger width
/// is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  if (type.getIntOrFloatBitWidth() <= 32)
    return value;
  return LLVM::TruncOp::create(b, b.getI32Type(), value);
}
/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = VectorType::get(2, i32Ty);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = VectorType::get(2, f64Ty);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = VectorType::get(2, f32Ty);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements() * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements() * 2, f32Ty));
  }
  if (a.getElementType() == VectorType::get(1, f32Ty)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f32Ty));
  }
  return vectorResultType;
}
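// Illustrative mapping (assuming the conventions above): an accumulator of
// type !llvm.array<2 x vector<2xf16>> maps to the intrinsic result type
// !llvm.struct<(vector<2xf16>, vector<2xf16>)>, whereas an
// !llvm.array<2 x vector<2xf32>> accumulator maps to a struct of four f32
// members, because the intrinsic returns those lanes as individual scalars.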
/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment of the given `resultType`.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
  Type i32x2Ty = VectorType::get(2, i32Ty);
  Type f64x2Ty = VectorType::get(2, f64Ty);
  Type f32x2Ty = VectorType::get(2, f32Ty);
  Type f32x1Ty = VectorType::get(1, f32Ty);

  auto makeConst = [&](int32_t index) -> Value {
    return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32),
                                    rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // so pairs of them have to be packed into the 64-bit wide rows of the
    // vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {
      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
        Value x1 =
            LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
        Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
                                                i * 2 + 1);
        vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
                                            x1, makeConst(0));
        vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
                                            x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
                                           el.index());
    }
    return result;
  }

  return intrinsicResult;
}
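// Both packing paths above seed the result with an LLVM poison value and then
// overwrite every member/lane via insertvalue/insertelement, so the final
// fragment contains no leftover undefined components.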
/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vectors` where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments as a long list of scalars of
/// certain types. This function helps unpack the `vector` arguments and cast
/// them to the types expected by `nvvm.mma.sync`.
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = b.getI32Type();
  Type f64Ty = b.getF64Type();
  Type f32Ty = b.getF32Type();
  Type i64Ty = b.getI64Type();
  Type i8x4Ty = VectorType::get(4, b.getI8Type());
  Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
  Type f32x1Ty = VectorType::get(1, f32Ty);
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = LLVM::ExtractValueOp::create(b, operand, i);

    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
    // scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
      continue;
    }

    // For some element types (i32, f32, f64), the intrinsic expects the
    // individual scalars, so unpack the inner vector type as well.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(LLVM::ExtractElementOp::create(
            b, toUse,
            LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}
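// Illustrative example (assuming the fragment shapes produced by
// gpu.mma.sync): an operand of type !llvm.array<2 x vector<2xf16>> unpacks
// into two vector<2xf16> values, while an !llvm.array<2 x vector<4xi8>>
// operand is bitcast element-wise into two i32 scalars, matching the scalar
// argument list that nvvm.mma.sync expects.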
/// Returns whether mbarrier object has shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = {};
  if (isMbarrierShared(barrierType))
    memorySpace =
        IntegerAttr::get(IntegerType::get(context, 64),
                         nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
  return memorySpace;
}

/// Return the memref type that can be used to represent an mbarrier object.
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
  MemRefLayoutAttrInterface layout;
  return MemRefType::get({barrierType.getNumBarriers()},
                         IntegerType::get(context, 64), layout, memorySpace);
}
struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // The result type of ldmatrix will always be a struct of 32bit integer
    // registers if more than one 32bit value is returned. Otherwise, the
    // result is a single i32.
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
                                           vectorResultType.getElementType());

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
    Value srcPtr =
        getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
                             adaptor.getSrcMemref(), adaptor.getIndices());
    auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
    Value ldMatrixResult = NVVM::LdMatrixOp::create(
        b, ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row,
        shape, NVVM::LdStMatrixEltType::B16);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values, cast them back to their actual
    // vector type (still 32b wide), and repack them into the result struct.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = LLVM::PoisonOp::create(b, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
      result =
          LLVM::InsertValueOp::create(b, finalResultType, result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
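// Each 32-bit ldmatrix register is reinterpreted as a small vector (e.g.
// vector<2xf16>) via bitcast before being packed back into the converted 2-D
// vector result, so the downstream mma.sync lowering sees the same fragment
// layout that nvgpu.ldmatrix advertises.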
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}
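// Note: f32 inputs are mapped to tf32 here because mma.sync consumes f32
// operands in the TensorFloat32 format; the MmaSync patterns below bail out
// when an f32 operand is not marked with the tf32-enabled attribute.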
struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMA operands.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Overflow (saturation) is only relevant for integer MMA operands.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = NVVM::MmaOp::create(
        b, intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/std::nullopt,
        /*intOverflowBehavior=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};
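// mma.sync expects the A fragment row-major and the B fragment column-major,
// which is why the multiplicand layouts are fixed to {row, col} above
// regardless of how the surrounding IR laid out the data.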
struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
  using Base::Base;

  void runOnOperation() override {
    LowerToLLVMOptions options(&getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext(), options);
    IRRewriter rewriter(&getContext());
    populateCommonGPUTypeAndAttributeConversions(converter);
    /// device-side async tokens cannot be materialized in nvvm. We just
    /// convert them to a dummy i32 type in order to easily drop them during
    /// conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);

      unsigned numMembers;
      if (elemType.isF32() || elemType.isInteger(32))
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
      else
        llvm_unreachable("unsupported type for warpgroup accumulator");

      SmallVector<Type> innerStructBody;
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

      SmallVector<Type> structBody;
      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
        structBody.push_back(innerStructType);

      auto convertedType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    });
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 64));
    });
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
          return converter.convertType(IntegerType::get(type.getContext(), 64));
        });
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
          nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
    });
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
/// Returns the constraint string for the inline assembly call to
/// `mma.sp.sync`.
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
    ss << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    ss << "r,";
  // The final operand is for the sparsity metadata; the sparsity selector
  // appears as a direct literal.
  ss << "r,X";
  return str;
}
/// Returns the string for the `mma.sp.sync` inline assembly instruction.
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  // The operand list is structured as `{D regs},{A regs},{B regs},{C regs}`
  // followed by the metadata register and the selector literal.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
  return asmStr;
}
/// Builds an inline assembly operation corresponding to the specified sparse
/// MMA sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  SmallVector<Value> asmVals;
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return LLVM::InlineAsmOp::create(b,
                                   /*resultTypes=*/intrinsicResultType,
                                   /*operands=*/asmVals,
                                   /*asm_string=*/asmStr,
                                   /*constraints=*/constraintStr,
                                   /*has_side_effects=*/true,
                                   /*is_align_stack=*/false,
                                   LLVM::TailCallKind::None,
                                   /*asm_dialect=*/asmDialectAttr,
                                   /*operand_attrs=*/ArrayAttr());
}
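// Illustrative output (assuming an m16n8k32 s8 x s8 -> s32 sparse MMA with
// fragment sizes A=4, B=2, C=2 and selector 0), the generated asm string
// looks roughly like:
//   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
//     {$0,$1},{$2,$3,$4,$5},{$6,$7},{$8,$9},$10,0x0;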
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMA operands.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Same as `mma.sync`: F32 works only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    // Overflow (saturation) is only relevant for integer MMA operands.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};
struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Location loc = op.getLoc();
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dstPtr =
        getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
                             adaptor.getDst(), adaptor.getDstIndices());
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

    Value scrPtr =
        getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
                             adaptor.getSrcIndices());
    // Intrinsic operates on global pointers, so cast the source accordingly.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
    scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
    // When the optional SrcElements argument is present, compute the number of
    // source bytes as (srcElements * elementBitWidth) >> 3.
    Value srcBytes = adaptor.getSrcElements();
    if (srcBytes) {
      Value c3I32 =
          LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3));
      Value bitwidth = LLVM::ConstantOp::create(
          b, b.getI32Type(),
          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
      Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
      srcBytes = LLVM::LShrOp::create(
          b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
    }
    // Bypass L1 is only supported for 16 byte wide transfers.
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    NVVM::CpAsyncOp::create(
        b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
        srcBytes);

    // Drop the result token.
    Value zero =
        LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32),
                                 rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};
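// Note: cp.async.cg (the L1-bypassing variant requested via bypassL1) is only
// legal for 16-byte transfers, which is why the cache modifier above falls
// back to .ca for any other copy size.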
struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
    // Drop the result token.
    Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
                                          IntegerType::get(op.getContext(), 32),
                                          rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};
struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservative correct value.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Creates mbarrier object in shared memory.
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = memref::GlobalOp::create(
        rewriter, funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/Attribute(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }
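  // The barriers live in a module-level shared-memory global ("__mbarrier"),
  // aligned to 8 bytes because each mbarrier object is a 64-bit word; every
  // nvgpu.mbarrier.create in the module then resolves to a memref.get_global
  // of that buffer.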
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
                                                     global.getName());
    return success();
  }
};
/// Base class for lowering mbarrier operations to NVVM intrinsics.
template <typename SourceOp>
struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  /// Returns the base pointer of the mbarrier object.
  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
                       Value mbarId,
                       ConversionPatternRewriter &rewriter) const {
    MemRefType mbarrierMemrefType =
        nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
    return ConvertToLLVMPattern::getStridedElementPtr(
        rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
  }
};
/// Lowers `nvgpu.mbarrier.get` to the address of the mbarrier object.
struct NVGPUMBarrierGetLowering
    : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Type resType = op.getMbarrierPointer().getType();
    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`.
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Value count = truncToI32(b, adaptor.getCount());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
                                                      adaptor.getPredicate());
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`.
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
/// `nvvm.mbarrier.arrive.nocomplete`.
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    Value count = truncToI32(b, adaptor.getCount());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
        op, tokenType, barrier, count);
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`.
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type retType = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier,
                                                          adaptor.getToken());
    return success();
  }
};
struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value txcount = truncToI32(b, adaptor.getTxcount());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
        op, barrier, txcount,
        NVVM::MemScopeKind::CTA,
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase =
        LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
                                                               phase, ticks);
    return success();
  }
};
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
                                      adaptor.getDst(), {});
    // The intrinsic expects a shared-cluster pointer for the destination.
    auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
        op->getContext(),
        static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
    dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);

    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);

    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        ValueRange{}, adaptor.getMulticastMask(), Value{},
        NVVM::TMALoadMode::TILE,
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
                                      adaptor.getSrc(), {});
    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
        op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
        NVVM::TMAStoreMode::TILE,
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

    unsigned layout =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
                                                                    : 1;
    unsigned swizzle =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
                                                                    : 0;

    auto ti64 = b.getIntegerType(64);
    auto makeConst = [&](uint64_t index) -> Value {
      return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index));
    };
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

    Value baseAddr = getStridedElementPtr(
        rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {});
    Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
    // Keep only 14 bits of the base address.
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
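    // Field layout of the 64-bit wgmma matrix descriptor, as implied by the
    // start bits above: swizzle kind at bit 62, base offset at bit 49, stride
    // dimension at bit 32, leading dimension at bit 16, and the 14-bit
    // shared-memory base address at bit 0 (with the 4 least-significant
    // address bits excluded via exclude4LSB).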
    Value dsc = makeConst(0);
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    dsc = insertBit(dsc, strideDim, startStrideBit);
    dsc = insertBit(dsc, leadDim, startLeadBit);
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
           << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
           << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
           << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
           << ")\n start_addr : " << baseAddr;

    rewriter.replaceOp(op, dsc);
    return success();
  }
};
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
  return LLVM::ConstantOp::create(b, b.getIntegerType(64),
                                  b.getI32IntegerAttr(index));
}
/// Returns a Value that holds the data type enum expected by the CUDA driver.
static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
  // Enum values are from the CUDA driver API (CUtensorMapDataType).
  enum CUtensorMapDataTypeEnum {
    CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
    CU_TENSOR_MAP_DATA_TYPE_UINT16,
    CU_TENSOR_MAP_DATA_TYPE_UINT32,
    CU_TENSOR_MAP_DATA_TYPE_INT32,
    CU_TENSOR_MAP_DATA_TYPE_UINT64,
    CU_TENSOR_MAP_DATA_TYPE_INT64,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
    CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
  };
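  // These values are forwarded to the runtime helper
  // (mgpuTensorMapEncodeTiledMemref, called further below) that encodes the
  // tensor map, so they must stay numerically in sync with the CUDA driver's
  // CUtensorMapDataType enum.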
  if (type.isUnsignedInteger(8))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
  if (type.isUnsignedInteger(16))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
  if (type.isUnsignedInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
  if (type.isUnsignedInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
  if (type.isSignlessInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
  if (type.isSignlessInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
  if (type.isF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
  if (type.isF32())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
  if (type.isF64())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
  if (type.isBF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);

  llvm_unreachable("Not supported data type");
}
struct NVGPUTmaCreateDescriptorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
    Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

    Value tensorElementType = elementTypeAsLLVMConstant(
        b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    Value boxArrayPtr = LLVM::AllocaOp::create(
        b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
                                      boxArrayPtr, makeI64Const(b, index));
      LLVM::StoreOp::create(b, value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    // Set arguments for the runtime function call.
    SmallVector<Value> arguments;
    arguments.push_back(promotedOperands[0]); // rank
    arguments.push_back(promotedOperands[1]); // descriptor
    arguments.push_back(tensorElementType);   // data type
    arguments.push_back(
        makeI64Const(b, (int)desc.getInterleave()));              // interleave
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));     // oob
    arguments.push_back(boxArrayPtr); // box dimensions

    // Set data types of the arguments.
    SmallVector<Type> argTypes = {
        llvmInt64Type,   /* int64_t tensorRank */
        llvmPointerType, /* ptr */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmPointerType  /* ptr */
    };
    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

    rewriter.replaceOp(op, tensorMap);
    return success();
  }
};
struct NVGPUWarpgroupMmaOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

  /// Helper class that generates the required NVVM ops for a warpgroup-level
  /// matrix multiplication.
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    ImplicitLocOpBuilder b;
    OpAdaptor adaptor;

    // Entire shape of the given op.
    int64_t totalM, totalN, totalK;

    // Shape of one wgmma instruction.
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

    // Iteration counts for the GEMM.
    int iterationM = 0, iterationN = 0, iterationK = 0;

    /// Returns the shape of the wgmma instruction as defined in the PTX
    /// programming guide.
    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      wgmmaM = 64;
      wgmmaN = sizeN;
      if (inputElemType.isTF32()) {
        wgmmaK = 8;
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
        wgmmaK = 16;
      } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
                 inputElemType.isInteger(8)) {
        wgmmaK = 32;
      } else if (inputElemType.isInteger(1)) {
        wgmmaK = 256;
      } else {
        llvm_unreachable("msg: not supported K shape");
      }
      LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
             << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
    }
    /// Generates a WGMMATypesAttr from an MLIR type.
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
        if (elemType.isF32() || elemType.isTF32())
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
        if (elemType.isBF16())
          return NVVM::WGMMATypes::bf16;
        if (isa<Float8E4M3FNType>(elemType))
          return NVVM::WGMMATypes::e4m3;
        if (isa<Float8E5M2Type>(elemType))
          return NVVM::WGMMATypes::e5m2;
        if (elemType.isInteger(1))
          return NVVM::WGMMATypes::b1;
        if (elemType.isInteger(8))
          return NVVM::WGMMATypes::s8;
        if (elemType.isUnsignedInteger(8))
          return NVVM::WGMMATypes::u8;
        if (elemType.isInteger(32))
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");
      };
      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
    }

    /// Generates the layout attribute for an input matrix of the wgmma op.
    NVVM::MMALayoutAttr
    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))
        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
    }

    /// Generates the shape attribute for the wgmma instruction.
    NVVM::MMAShapeAttr generateWgmmaShape() const {
      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
    }

    /// Generates the scale-out attribute for the wgmma instruction.
    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
                                          NVVM::WGMMAScaleOut::one);
    }
    /// Generates the scale-in attribute for the wgmma instruction.
    NVVM::WGMMAScaleInAttr generateScaleIn() const {
      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
                                         NVVM::WGMMAScaleIn::one);
    }
    /// Basic helper to generate an add.
    Value makeAdd(Value lhs, Value rhs) {
      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    };

    /// Moves the matrix-A descriptor forward for the wgmma at (i, j, k) by the
    /// byte offset of the tile it should read.
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();
      int byte = elemA.getIntOrFloatBitWidth() / 8;
      int tileShapeA = matrixTypeA.getDimSize(1);
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
             << "] [wgmma descriptors] Descriptor A + " << incrementVal;
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// Moves the matrix-B descriptor forward for the next wgmma instruction.
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LDBG() << "Descriptor B + " << incrementVal;
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }
    /// Generates a WgmmaMmaAsyncOp using the provided GMMA matrix descriptors
    /// and arranges them based on the induction variables i, j, and k.
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
             << "(A[" << (iterationM * wgmmaM) << ":"
             << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
             << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
             << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
             << "][" << 0 << ":" << wgmmaN << "])";

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return NVVM::WgmmaMmaAsyncOp::create(
          b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
          itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }
    /// Generates multiple wgmma instructions to complete the given GEMM shape.
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());

      // Perform the GEMM.
      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        Value matrixC =
            LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
                                                  wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }
  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM shape.
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
             << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
             << "] ---===";

      // Find the shape of a single wgmma instruction.
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iteration counts needed to cover the given shape with wgmma.
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
    }
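    // The full GEMM is therefore tiled into iterationM x iterationN x
    // iterationK wgmma.mma_async issues, each computing a wgmmaM x wgmmaN x
    // wgmmaK piece of D += A * B.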
    /// Generates the wgmma ops between a fence and a group commit/wait so the
    /// generated group completes before the results are used.
    Value generateWarpgroupMma() {
      NVVM::WgmmaFenceAlignedOp::create(b);
      Value wgmmaResult = generateWgmmaGroup();
      NVVM::WgmmaGroupSyncAlignedOp::create(b);
      NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());
      return wgmmaResult;
    }
  };

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build the helper class.
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Generate the full GEMM.
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace the op with the fragmented result struct.
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};
struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// Stores a fragmented register matrix owned by a warp group (128 threads)
  /// into the given memref.
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(kWarpSize);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<::mlir::MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = arith::IndexCastOp::create(b, it, x);
      Value idy0 = arith::IndexCastOp::create(b, it, y);
      Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
      Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
      Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
      memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
      memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = NVVM::ThreadIdXOp::create(b, i32);
    Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
    Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
    Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
    Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    if (offset)
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    // Number of adjacent 32-bit registers a thread owns per row.
    constexpr unsigned numAdjacentRegisters = 2;
    // Number of 8x8 matrices stacked below one another per warp.
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }
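  // Fragment layout recap, as implied by the index math above: within a warp,
  // laneId / 4 selects the row of an 8x8 sub-tile and (laneId % 4) * 2 the
  // starting column of the two adjacent registers a thread owns; warps are
  // stacked every 16 rows, the two stacked sub-tiles are 8 rows apart, and
  // successive register pairs step 8 columns to the right.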
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matriDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          LLVM::ExtractValueOp::create(b, matriDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
    rewriter.eraseOp(op);
    return success();
  }
};
struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType =
        cast<LLVM::LLVMStructType>(packStructType.getBody().front())
            .getBody()
            .front();
    Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
    Value packStruct = LLVM::PoisonOp::create(b, packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the structs and set all values to zero.
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = LLVM::InsertValueOp::create(b, structType, structValue,
                                                  zero, ArrayRef<int64_t>({i}));
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs back into a single struct.
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
                                               packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};
struct NVGPUTmaFenceOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = op.getContext();
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i32Ty = b.getI32Type();
    Value tensormapSize =
        LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128));

    auto memscope =
        NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);

    rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
        op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
    return success();
  }
};
struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
        op, nullptr, nullptr,
        adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
        mlir::UnitAttr::get(op.getContext()));
    return success();
  }
};
struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i64Ty = b.getI64Type();
    auto f32Ty = b.getF32Type();
    VectorType inTy = op.getIn().getType();
    // Apply rcp.approx.ftz.f to each element of the vector.
    auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
      Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
      int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
      for (int i = 0; i < numElems; i++) {
        Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i));
        Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
        Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
        ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
      }
      return ret1DVec;
    };
    if (inTy.getRank() == 1) {
      rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
      return success();
    }
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
          OpAdaptor adaptor(operands);
          return convert1DVec(llvm1DVectorTy, adaptor.getIn());
        },
        rewriter);
  }
};
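// rcp.approx.ftz.f32 is a scalar instruction, so the lowering scalarizes the
// vector: rank-1 vectors are expanded element by element, and higher-rank
// vectors are first peeled into 1-D pieces via handleMultidimensionalVectors.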
  // Remap the GPU address-space enums onto NVVM address spaces.
  populateGpuMemorySpaceAttributeConversions(
      typeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
      });

  patterns.add<NVGPUMBarrierCreateLowering,
               NVGPUMBarrierInitLowering,
               NVGPUMBarrierGetLowering,
               NVGPUMBarrierArriveLowering,
               NVGPUMBarrierArriveNoCompleteLowering,
               NVGPUMBarrierTestWaitLowering,
               NVGPUMBarrierTryWaitParityLowering,
               NVGPUTmaAsyncLoadOpLowering,
               NVGPUTmaAsyncStoreOpLowering,
               NVGPUTmaCreateDescriptorOpLowering,
               NVGPUTmaPrefetchOpLowering,
               NVGPUTmaFenceOpLowering,
               NVGPUMBarrierArriveExpectTxLowering,
               NVGPUGenerateWarpgroupDescriptorLowering,
               NVGPUWarpgroupMmaOpLowering,
               NVGPUWarpgroupMmaStoreOpLowering,
               NVGPUWarpgroupMmaInitAccumulatorOpLowering,
               MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
               NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
}