29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
34 #define DEBUG_TYPE "nvgpu-to-nvvm"
35 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36 #define DBGSE() (llvm::dbgs())
39 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
40 #include "mlir/Conversion/Passes.h.inc"
53 assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
63 auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
71 if (a.getElementType() == f16x2Ty) {
72 return LLVM::LLVMStructType::getLiteral(
75 if (a.getElementType() == i32x2Ty) {
76 return LLVM::LLVMStructType::getLiteral(
80 if (a.getElementType() == f64x2Ty) {
81 return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
83 if (a.getElementType() == f32x2Ty) {
84 return LLVM::LLVMStructType::getLiteral(
89 return LLVM::LLVMStructType::getLiteral(
92 return vectorResultType;
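// Illustrative mapping (assumed example, not text from the upstream file): a
// gpu.mma.sync accumulator lowered to !llvm.array<2 x vector<2xf16>> is mapped
// by the helper above to !llvm.struct<(vector<2xf16>, vector<2xf16>)>, i.e.
// the struct shape nvvm.mma.sync actually returns; convertIntrinsicResult
// below repacks that struct into the array form of the converted result type.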
104 auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
105 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
115 auto makeConst = [&](int32_t index) -> Value {
125 if (arrayType.getElementType() == f16x2Ty ||
126 arrayType.getElementType() == f32x1Ty) {
127 for (unsigned i = 0; i < structType.getBody().size(); i++) {
129 rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
131 loc, arrayType.getElementType(), el);
132 elements.push_back(el);
140 if (arrayType.getElementType() == i32x2Ty ||
141 arrayType.getElementType() == f64x2Ty ||
142 arrayType.getElementType() == f32x2Ty) {
144 for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
146 rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType());
148 rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
149 Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
151 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
153 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
155 elements.push_back(vec);
160 Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType);
162 result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
168 return intrinsicResult;
178 NVVM::MMATypes operandPtxType) {
187 auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
189 for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
190 Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
194 if (arrayTy.getElementType() == i8x4Ty ||
195 arrayTy.getElementType() == i4x8Ty ||
196 (arrayTy.getElementType() == f32x1Ty &&
197 operandPtxType == NVVM::MMATypes::tf32)) {
198 result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
205 VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
206 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
207 innerArrayTy.getElementType() == f64Ty ||
208 innerArrayTy.getElementType() == f32Ty)) {
209 for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
210 idx < innerSize; idx++) {
211 result.push_back(b.create<LLVM::ExtractElementOp>(
217 result.push_back(toUse);
224 return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
225 barrierType.getMemorySpace()));
230 nvgpu::MBarrierGroupType barrierType) {
235 nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
243 nvgpu::MBarrierGroupType barrierType) {
245 MemRefLayoutAttrInterface layout;
256 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
267 auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
268 if (!vectorResultType) {
272 vectorResultType.getElementType());
274 int64_t num32BitRegs = vectorResultType.getDimSize(0);
276 Type ldMatrixResultType;
277 if (num32BitRegs > 1) {
278 ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
284 auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
286 getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
287 adaptor.getIndices(), rewriter);
288 Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289 ldMatrixResultType, srcPtr,
291 op.getTranspose() ? NVVM::MMALayout::col
292 : NVVM::MMALayout::row);
298 Type finalResultType = typeConverter->convertType(vectorResultType);
299 Value result = b.create<LLVM::PoisonOp>(finalResultType);
300 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
302 num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
304 Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
305 result = b.create<LLVM::InsertValueOp>(result, casted, i);
315 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
318 return NVVM::MMATypes::s8;
320 return NVVM::MMATypes::s4;
322 return NVVM::MMATypes::f16;
324 return NVVM::MMATypes::f64;
326 return NVVM::MMATypes::tf32;
334 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
339 VectorType aType = op.getMatrixA().getType();
340 VectorType bType = op.getMatrixB().getType();
341 VectorType cType = op.getMatrixC().getType();
343 std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
346 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347 if (aType.getElementType().isF32() && !tf32Enabled)
350 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351 if (failed(ptxTypeA))
352 return op->emitOpError("failed to deduce operand PTX types");
353 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354 if (failed(ptxTypeB))
355 return op->emitOpError("failed to deduce operand PTX types");
356 std::optional<NVVM::MMATypes> ptxTypeC =
357 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
360 return op->emitError(
361 "could not infer the PTX type for the accumulator/result");
364 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365 if (isa<IntegerType>(aType.getElementType()))
366 overflow = NVVM::MMAIntOverflow::satfinite;
375 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
377 typeConverter->convertType(op->getResultTypes()[0]));
378 Value intrinsicResult = b.create<NVVM::MmaOp>(
379 intrinsicResTy, matA, matB, matC,
384 std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
386 std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
387 NVVM::MMALayout::col});
389 desiredRetTy, intrinsicResult,
395 struct ConvertNVGPUToNVVMPass
396 : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
400 registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
401 arith::ArithDialect>();
404 void runOnOperation() override {
410 converter, [](gpu::AddressSpace space) -> unsigned {
412 case gpu::AddressSpace::Global:
413 return static_cast<unsigned>(
415 case gpu::AddressSpace::Workgroup:
416 return static_cast<unsigned>(
418 case gpu::AddressSpace::Private:
421 llvm_unreachable("unknown address space enum value");
427 converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
430 converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
431 Type elemType = type.getFragmented().getElementType();
432 int64_t sizeM = type.getFragmented().getDimSize(0);
433 int64_t sizeN = type.getFragmented().getDimSize(1);
437 numMembers = sizeN / 2;
438 else if (elemType.isF16())
439 numMembers = sizeN / 4;
441 llvm_unreachable("unsupported type for warpgroup accumulator");
444 for (unsigned i = 0; i < numMembers; i++)
445 innerStructBody.push_back(elemType);
446 auto innerStructType =
447 LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
451 structBody.push_back(innerStructType);
454 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
455 return converter.convertType(convertedType);
457 converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
460 converter.addConversion(
461 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
464 converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
465 return converter.convertType(
468 converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
473 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
474 target.addLegalDialect<::mlir::arith::ArithDialect>();
475 target.addLegalDialect<::mlir::memref::MemRefDialect>();
476 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
486 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
490 llvm::raw_string_ostream ss(str);
491 for (unsigned i = 0; i < matCSize; i++)
493 for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
505 static std::string buildMmaSparseAsmString(
506 const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
507 unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
508 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
509 std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
510 auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
511 return NVVM::stringifyMMATypes(ptxType);
515 llvm::raw_string_ostream ss(asmStr);
516 ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
517 << shape[2] << ".row.col.";
520 ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
522 ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
523 << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
524 unsigned asmArgIdx = 0;
528 for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
530 for (unsigned i = 0; i < arrSize; i++)
531 ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
534 ss << "$" << asmArgIdx++ << ",";
535 assert(metaDataSelector <= 1);
536 ss << "0x" << metaDataSelector << ";";
542 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
544 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
545 std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
547 int64_t metadataSelector, const std::array<int64_t, 3> &shape,
548 Type intrinsicResultType) {
549 auto asmDialectAttr =
552 const unsigned matASize = unpackedAData.size();
553 const unsigned matBSize = unpackedB.size();
554 const unsigned matCSize = unpackedC.size();
556 std::string asmStr = buildMmaSparseAsmString(
557 shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
558 ptxTypeD, overflow, metadataSelector);
559 std::string constraintStr =
560 buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
563 asmVals.reserve(matASize + matBSize + matCSize + 1);
565 llvm::append_range(asmVals, args);
566 asmVals.push_back(indexData);
568 return b.create<LLVM::InlineAsmOp>(
580 struct NVGPUMmaSparseSyncLowering
585 matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
590 VectorType aType = op.getMatrixA().getType();
591 VectorType bType = op.getMatrixB().getType();
592 VectorType cType = op.getMatrixC().getType();
594 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
595 if (failed(ptxTypeA))
596 return op->emitOpError("failed to deduce operand PTX types");
597 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
598 if (failed(ptxTypeB))
599 return op->emitOpError("failed to deduce operand PTX types");
600 std::optional<NVVM::MMATypes> ptxTypeC =
601 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
604 return op->emitError(
605 "could not infer the PTX type for the accumulator/result");
608 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
609 if (aType.getElementType().isF32() && !tf32Enabled)
613 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
614 if (isa<IntegerType>(aType.getElementType()))
615 overflow = NVVM::MMAIntOverflow::satfinite;
624 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
626 typeConverter->convertType(op->getResultTypes()[0]));
629 Value sparseMetadata = adaptor.getSparseMetadata();
631 return op->emitOpError() << "Expected metadata type to be LLVM "
632 "VectorType of 2 i16 elements";
636 FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
637 b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
638 matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
640 if (failed(intrinsicResult))
643 assert((*intrinsicResult).getNumResults() == 1 &&
644 "expected inline asm op returns a single LLVM struct type");
647 (*intrinsicResult)->getResult(0), rewriter));
652 struct NVGPUAsyncCopyLowering
658 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
662 auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
664 getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
665 adaptor.getDstIndices(), rewriter);
666 FailureOr<unsigned> dstAddressSpace =
667 getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
668 if (failed(dstAddressSpace))
670 loc, "destination memref address space not convertible to integer");
672 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
673 FailureOr<unsigned> srcAddressSpace =
674 getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
675 if (failed(srcAddressSpace))
677 loc, "source memref address space not convertible to integer");
679 Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
680 adaptor.getSrcIndices(), rewriter);
684 scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
685 int64_t dstElements = adaptor.getDstElements().getZExtValue();
686 int64_t sizeInBytes =
687 (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
692 Value srcBytes = adaptor.getSrcElements();
704 srcBytes = b.create<LLVM::LShrOp>(
705 b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
709 NVVM::LoadCacheModifierKind cacheModifier =
710 (op.getBypassL1().value_or(false) && sizeInBytes == 16)
711 ? NVVM::LoadCacheModifierKind::CG
712 : NVVM::LoadCacheModifierKind::CA;
714 b.create<NVVM::CpAsyncOp>(
727 struct NVGPUAsyncCreateGroupLowering
733 matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
735 rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
745 struct NVGPUAsyncWaitLowering
751 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
754 int32_t numGroups = adaptor.getNumGroups().value_or(0);
755 rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
762 struct NVGPUMBarrierCreateLowering
766 template <typename moduleT>
769 MemRefType barrierType) const {
773 auto global = rewriter.create<memref::GlobalOp>(
774 funcOp->getLoc(), "__mbarrier",
780 symbolTable.insert(global);
785 matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
789 rewriter.getContext(), op.getBarriers().getType());
791 memref::GlobalOp global;
793 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
795 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
805 template <typename SourceOp>
811 nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
814 MemRefType mbarrierMemrefType =
817 b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
821 struct NVGPUMBarrierGetLowering
822 : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
823 using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
826 matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
829 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
831 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
832 adaptor.getMbarId(), rewriter);
833 Type resType = op.getMbarrierPointer().getType();
840 struct NVGPUMBarrierInitLowering
841 : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
842 using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
845 matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
848 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
850 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
851 adaptor.getMbarId(), rewriter);
855 op, barrier, count, adaptor.getPredicate());
858 adaptor.getPredicate());
865 struct NVGPUMBarrierArriveLowering
866 : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
867 using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
869 matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
873 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
874 adaptor.getMbarId(), rewriter);
875 Type tokenType = getTypeConverter()->convertType(
890 struct NVGPUMBarrierArriveNoCompleteLowering
891 : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
892 using MBarrierBasePattern<
893 nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
895 matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
899 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
900 adaptor.getMbarId(), rewriter);
901 Type tokenType = getTypeConverter()->convertType(
906 op, tokenType, barrier, count);
909 op, tokenType, barrier, count);
916 struct NVGPUMBarrierTestWaitLowering
917 : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
918 using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
920 matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
924 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
925 adaptor.getMbarId(), rewriter);
929 op, retType, barrier, adaptor.getToken());
932 op, retType, barrier, adaptor.getToken());
938 struct NVGPUMBarrierArriveExpectTxLowering
939 : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
940 using MBarrierBasePattern<
941 nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
943 matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
947 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
948 adaptor.getMbarId(), rewriter);
953 op, barrier, txcount, adaptor.getPredicate());
958 op, barrier, txcount, adaptor.getPredicate());
963 struct NVGPUMBarrierTryWaitParityLowering
964 : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
965 using MBarrierBasePattern<
966 nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
968 matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
972 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
973 adaptor.getMbarId(), rewriter);
980 op, barrier, phase, ticks);
990 struct NVGPUTmaAsyncLoadOpLowering
991 : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
992 using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
994 matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
997 auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
998 Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
999 adaptor.getDst(), {}, rewriter);
1001 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
1002 adaptor.getMbarId(), rewriter);
1009 op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
1011 adaptor.getPredicate());
1016 struct NVGPUTmaAsyncStoreOpLowering
1017 : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1018 using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1020 matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1023 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1024 Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1025 adaptor.getSrc(), {}, rewriter);
1032 op, adaptor.getTensorMapDescriptor(), dest, coords,
1033 adaptor.getPredicate());
1038 struct NVGPUGenerateWarpgroupDescriptorLowering
1044 matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1049 nvgpu::TensorMapSwizzleKind swizzleKind =
1050 op.getTensorMap().getType().getSwizzle();
1053 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1054 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1055 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1058 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1059 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1060 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1064 auto makeConst = [&](uint64_t index) -> Value {
1067 auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1068 return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1070 auto shiftRight = [&](Value value, unsigned shift) -> Value {
1071 return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1073 auto insertBit = [&](Value desc, Value val, int startBit) {
1074 return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1077 int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1078 uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1079 uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1080 uint64_t offsetVal = 0;
1082 Value strideDim = makeConst(strideDimVal);
1083 Value leadDim = makeConst(leadDimVal);
1085 Value baseAddr = getStridedElementPtr(
1086 op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1087 adaptor.getTensor(), {}, rewriter);
1088 Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1090 Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1092 int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1093 startLeadBit = 16, startBaseAddrBit = 0;
1094 Value dsc = makeConst(0);
1096 dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1098 dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1100 dsc = insertBit(dsc, strideDim, startStrideBit);
1102 dsc = insertBit(dsc, leadDim, startLeadBit);
1104 dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1106 LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1107 << "leading_off:" << leadDimVal << "\t"
1108 << "stride_off :" << strideDimVal << "\t"
1109 << "base_offset:" << offsetVal << "\t"
1110 << "layout_type:" << swizzle << " ("
1111 << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1112 << ")\n start_addr : " << baseAddr << "\n");
1128 enum CUtensorMapDataTypeEnum {
1129 CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1130 CU_TENSOR_MAP_DATA_TYPE_UINT16,
1131 CU_TENSOR_MAP_DATA_TYPE_UINT32,
1132 CU_TENSOR_MAP_DATA_TYPE_INT32,
1133 CU_TENSOR_MAP_DATA_TYPE_UINT64,
1134 CU_TENSOR_MAP_DATA_TYPE_INT64,
1135 CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1136 CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1137 CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1138 CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1139 CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1140 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1141 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1145 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1147 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1149 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1151 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1153 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1155 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1157 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1159 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1161 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1163 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1165 llvm_unreachable("Not supported data type");
1168 struct NVGPUTmaCreateDescriptorOpLowering
1173 matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1179 Value tensorElementType =
1180 elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1181 auto promotedOperands = getTypeConverter()->promoteOperands(
1182 b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1184 Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1185 makeI64Const(b, 5));
1186 for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1187 Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1188 boxArrayPtr, makeI64Const(b, index));
1189 b.create<LLVM::StoreOp>(value, gep);
1192 nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1195 arguments.push_back(promotedOperands[0]);
1196 arguments.push_back(promotedOperands[1]);
1197 arguments.push_back(tensorElementType);
1198 arguments.push_back(
1199 makeI64Const(b, (int)desc.getInterleave()));
1200 arguments.push_back(makeI64Const(b, (int)desc.getSwizzle()));
1201 arguments.push_back(makeI64Const(b, (int)desc.getL2promo()));
1202 arguments.push_back(makeI64Const(b, (int)desc.getOob()));
1203 arguments.push_back(boxArrayPtr);
1217 "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1219 hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1226 struct NVGPUWarpgroupMmaOpLowering
1250 class WarpgroupGemm {
1251 nvgpu::WarpgroupMmaOp op;
1256 int64_t totalM, totalN, totalK;
1259 int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1262 int iterationM = 0, iterationN = 0, iterationK = 0;
1267 void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1270 if (inputElemType.isTF32()) {
1272 } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1274 } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1277 } else if (inputElemType.isInteger(1)) {
1280 llvm_unreachable("msg: not supported K shape");
1282 LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1283 << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1287 NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1288 bool useF32 = false) const {
1289 auto getWgmmaType = [=](Type elemType) {
1291 return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1292 if (elemType.isF16())
1293 return NVVM::WGMMATypes::f16;
1295 return NVVM::WGMMATypes::bf16;
1296 if (isa<Float8E4M3FNType>(elemType))
1297 return NVVM::WGMMATypes::e4m3;
1298 if (isa<Float8E5M2Type>(elemType))
1299 return NVVM::WGMMATypes::e5m2;
1301 return NVVM::WGMMATypes::b1;
1303 return NVVM::WGMMATypes::s8;
1305 return NVVM::WGMMATypes::u8;
1307 return NVVM::WGMMATypes::s32;
1308 llvm_unreachable("unsupported type");
1315 generateWgmmaLayout(std::optional<bool> transpose) const {
1322 NVVM::MMAShapeAttr generateWgmmaShape() const {
1327 NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1329 NVVM::WGMMAScaleOut::one);
1332 NVVM::WGMMAScaleInAttr generateScaleIn() const {
1334 NVVM::WGMMAScaleIn::one);
1360 Value iterateDescriptorA(Value desc, int i, int j, int k) {
1361 MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1362 Type elemA = matrixTypeA.getElementType();
1364 int tileShapeA = matrixTypeA.getDimSize(1);
1365 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1367 LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1368 << "] [wgmma descriptors] Descriptor A + "
1369 << incrementVal << " | \t ");
1372 return makeAdd(desc, makeI64Const(b, incrementVal));
1386 Value iterateDescriptorB(Value desc, int i, int j, int k) {
1387 MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1388 Type elemB = matrixTypeB.getElementType();
1390 int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1392 LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1395 return makeAdd(desc, makeI64Const(b, incrementVal));
1400 Value generateWgmma(int i, int j, int k, Value matrixC) {
1401 LLVM_DEBUG(DBGS() << "\t wgmma."
1402 << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1403 << "(A[" << (iterationM * wgmmaM) << ":"
1404 << (iterationM * wgmmaM) + wgmmaM << "]["
1405 << (iterationK * wgmmaK) << ":"
1406 << (iterationK * wgmmaK + wgmmaK) << "] * "
1407 << " B[" << (iterationK * wgmmaK) << ":"
1408 << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1409 << wgmmaN << "])\n");
1411 Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1412 Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1414 Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1415 NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1417 Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1418 NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1420 Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1421 NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1423 NVVM::MMAShapeAttr shape = generateWgmmaShape();
1424 NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1425 NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1426 NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1427 NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1430 op->getContext(), NVVM::MMAIntOverflow::wrapped);
1432 return b.create<NVVM::WgmmaMmaAsyncOp>(
1433 matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1434 itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1439 Value generateWgmmaGroup() {
1441 b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType());
1445 for (int i = 0; i < iterationM; ++i) {
1446 Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1447 for (int j = 0; j < iterationN; ++j)
1448 for (int k = 0; k < iterationK; ++k)
1449 matrixC = generateWgmma(i, j, k, matrixC);
1450 wgmmaResults.push_back(matrixC);
1453 wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
1454 wgmmaResult, matrix, idx);
1462 : op(op), b(b), adaptor(adaptor) {
1464 totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1465 totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1466 totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1467 LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1468 << "] += A[" << totalM << "][" << totalK << "] * B["
1469 << totalK << "][" << totalN << "] ---===\n");
1474 op.getDescriptorA().getType().getTensor().getElementType());
1477 iterationM = totalM / wgmmaM;
1478 iterationN = totalN / wgmmaN;
1479 iterationK = totalK / wgmmaK;
1487 Value generateWarpgroupMma() {
1488 b.create<NVVM::WgmmaFenceAlignedOp>();
1489 Value wgmmaResult = generateWgmmaGroup();
1490 b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1491 b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
1496 matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1501 WarpgroupGemm warpgroupGemm(op, b, adaptor);
1504 Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
1512 struct NVGPUWarpgroupMmaStoreOpLowering
1558 auto makeConst = [&](int32_t index) -> Value {
1561 Value c1 = makeConst(1);
1562 Value c2 = makeConst(2);
1563 Value c4 = makeConst(4);
1564 Value c8 = makeConst(8);
1565 Value c16 = makeConst(16);
1575 auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1580 Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1581 Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1582 Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1588 Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1589 Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1590 Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1591 Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1593 Value tj = makeMul(lane4modId, c2);
1594 Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1596 ti = makeAdd(ti, makeConst(offset));
1598 auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1601 constexpr unsigned numAdjacentRegisters = 2;
1603 constexpr unsigned numStackedMatrices = 2;
1605 size_t storeCount = (structType.getBody().size() /
1606 (numStackedMatrices * numAdjacentRegisters));
1608 for (size_t i = 0; i < numStackedMatrices; ++i) {
1609 Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1610 for (size_t j = 0; j < storeCount; ++j) {
1611 Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1612 size_t structIndex = (i * numAdjacentRegisters) +
1613 (j * (numStackedMatrices * numAdjacentRegisters));
1614 makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
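// Recap of the index math above (an assumed reading, not an upstream comment):
// with laneId = tidx % warpSize and warpId = tidx / warpSize, each lane writes
// rows ti = warpId * 16 + laneId / 4 (plus the caller-provided offset) and
// ti + 8, and column pairs starting at tj = (laneId % 4) * 2, stepping by 8
// per j; struct elements structIndex and structIndex + 1 are the two adjacent
// values stored at (idx, idy) and (idx, idy + 1) by makeExtractAndStore.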
1620 matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1624 Value matriDValue = adaptor.getMatrixD();
1625 auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
1627 auto structType = cast<LLVM::LLVMStructType>(matrixD);
1628 Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
1629 storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1630 offset += structType.getBody().size();
1637 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1642 matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1645 LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
1646 getTypeConverter()->convertType(op.getMatrixC().getType()));
1647 Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1651 Value packStruct = b.create<LLVM::PoisonOp>(packStructType);
1655 auto structType = cast<LLVM::LLVMStructType>(s);
1656 Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1657 for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1658 structValue = b.create<LLVM::InsertValueOp>(
1661 innerStructs.push_back(structValue);
1665 packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1666 packStruct, matrix, idx);
1673 struct NVGPUTmaFenceOpLowering
1677 matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
1682 Value tensormapSize =
1689 op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);
1695 struct NVGPUTmaPrefetchOpLowering
1699 matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1702 op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1710 matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1715 VectorType inTy = op.getIn().getType();
1717 auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1718 Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy);
1719 int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1720 for (int i = 0; i < numElems; i++) {
1722 Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
1723 Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
1724 ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
1728 if (inTy.getRank() == 1) {
1729 rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1733 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1735 OpAdaptor adaptor(operands);
1736 return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1746 NVGPUMBarrierCreateLowering,
1747 NVGPUMBarrierInitLowering,
1748 NVGPUMBarrierGetLowering,
1749 NVGPUMBarrierArriveLowering,
1750 NVGPUMBarrierArriveNoCompleteLowering,
1751 NVGPUMBarrierTestWaitLowering,
1752 NVGPUMBarrierTryWaitParityLowering,
1753 NVGPUTmaAsyncLoadOpLowering,
1754 NVGPUTmaAsyncStoreOpLowering,
1755 NVGPUTmaCreateDescriptorOpLowering,
1756 NVGPUTmaPrefetchOpLowering,
1757 NVGPUTmaFenceOpLowering,
1758 NVGPUMBarrierArriveExpectTxLowering,
1759 NVGPUGenerateWarpgroupDescriptorLowering,
1760 NVGPUWarpgroupMmaOpLowering,
1761 NVGPUWarpgroupMmaStoreOpLowering,
1762 NVGPUWarpgroupMmaInitAccumulatorOpLowering,
1763 MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1764 NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1765 NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
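// Usage sketch (assumed driver code, not part of this file): the patterns
// registered above are normally exercised through the conversion pass, e.g.
//   mlir-opt input.mlir -convert-nvgpu-to-nvvm
// or, programmatically, by adding the pass to a pipeline:
//   pm.addPass(createConvertNVGPUToNVVMPass());
// populateNVGPUToNVVMConversionPatterns can also be called directly when
// composing a custom NVVM/LLVM lowering pipeline.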