#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define DBGSE() (llvm::dbgs())

#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
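// Illustrative note (not part of the original source): within this pass the
// debug macros above are typically used together with LLVM_DEBUG, e.g.
//   LLVM_DEBUG(DBGS() << "lowering nvgpu.mma.sync\n");
// which prints "[nvgpu-to-nvvm] lowering nvgpu.mma.sync" when the pass is run
// with -debug-only=nvgpu-to-nvvm.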
/// GPU has 32 bit registers, this function truncates values when larger width
/// is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  return LLVM::TruncOp::create(b, b.getI32Type(), value);
/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
  return LLVM::LLVMStructType::getLiteral(
  return vectorResultType;
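// Illustrative sketch (assumed example, not in the original source): a
// gpu.mma.sync fragment of type !llvm.array<2 x vector<2xf16>> maps to the
// intrinsic result type !llvm.struct<(vector<2xf16>, vector<2xf16>)>, i.e.
// one struct member per <2 x f16> row of the fragment.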
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);

  auto makeConst = [&](int32_t index) -> Value {

  if (arrayType.getElementType() == f16x2Ty ||
      arrayType.getElementType() == f32x1Ty) {
    for (unsigned i = 0; i < structType.getBody().size(); i++) {
      LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
          loc, arrayType.getElementType(), el);
      elements.push_back(el);

  if (arrayType.getElementType() == i32x2Ty ||
      arrayType.getElementType() == f64x2Ty ||
      arrayType.getElementType() == f32x2Ty) {
    for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
      LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
      LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
      Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
      vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
      vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
      elements.push_back(vec);

  Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
  result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
  return intrinsicResult;
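// Illustrative sketch (assumed example, not in the original source): for an
// f16 accumulator the intrinsic returns !llvm.struct<(vector<2xf16>, ...)>;
// each struct member is extracted and re-inserted into the
// !llvm.array<N x vector<2xf16>> fragment type expected at the nvgpu level.
// For i32/f64/f32 accumulators two scalar struct members are packed into each
// 2-element vector before insertion into the array.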
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(LLVM::ExtractElementOp::create(
    result.push_back(toUse);
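// Illustrative sketch (assumed example, not in the original source): an A
// operand of type !llvm.array<2 x vector<2xf16>> is unpacked into two
// vector<2xf16> registers, while packed i8/i4 vectors (and single f32 lanes
// used as tf32) are bitcast to i32 so they can be passed to nvvm.mma.sync as
// the scalar register operands it expects.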
/// Returns whether mbarrier object has shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
      nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);

MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  MemRefLayoutAttrInterface layout;
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
        vectorResultType.getElementType());
    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
        adaptor.getSrcMemref(), adaptor.getIndices());
    Value ldMatrixResult = NVVM::LdMatrixOp::create(
        b, ldMatrixResultType, srcPtr,
        op.getTranspose() ? NVVM::MMALayout::col : NVVM::MMALayout::row);

    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = LLVM::PoisonOp::create(b, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
      Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
      result = LLVM::InsertValueOp::create(b, result, casted, i);
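// Illustrative sketch (assumed example, not in the original source):
//   %frag = nvgpu.ldmatrix %smem[%c0, %c0] {numTiles = 4 : i32,
//           transpose = false} : memref<16x16xf16, 3> -> vector<4x2xf16>
// lowers to an nvvm.ldmatrix returning !llvm.struct<(i32, i32, i32, i32)>;
// each i32 register is then bitcast to vector<2xf16> and inserted into the
// !llvm.array<4 x vector<2xf16>> result. The exact IR above is an assumption
// used for illustration only.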
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult =
        NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
            std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
            std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                           NVVM::MMALayout::col});
        desiredRetTy, intrinsicResult,
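// Illustrative sketch (assumed example, not in the original source):
//   %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
//        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// becomes an nvvm.mma.sync with shape m16n8k16, row/col operand layouts, and
// f16 operand/accumulator PTX types; integer operands additionally select the
// satfinite overflow mode computed above.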
struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect,
                    NVVM::NVVMDialect, arith::ArithDialect>();
  void runOnOperation() override {
        converter, [](gpu::AddressSpace space) -> unsigned {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
          case gpu::AddressSpace::Private:
          llvm_unreachable("unknown address space enum value");
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
        llvm_unreachable("unsupported type for warpgroup accumulator");
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
        structBody.push_back(innerStructType);
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();
  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);
  return LLVM::InlineAsmOp::create(b,
struct NVGPUMmaSparseSyncLowering
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)

    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
        typeConverter->convertType(op->getResultTypes()[0]));

    Value sparseMetadata = adaptor.getSparseMetadata();
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
        LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
    if (failed(intrinsicResult))
    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
        (*intrinsicResult)->getResult(0), rewriter));
struct NVGPUAsyncCopyLowering
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
        adaptor.getDst(), adaptor.getDstIndices());
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
          loc, "source memref address space not convertible to integer");

        adaptor.getSrcIndices());
    scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
    Value srcBytes = adaptor.getSrcElements();
    Value bitwidth = LLVM::ConstantOp::create(
    Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
    srcBytes = LLVM::LShrOp::create(
        b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);

    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;
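// Illustrative note (assumption, not in the original source): the source size
// in bytes is (elements * bitwidth) >> 3, so copying 8 f16 elements gives
// 16 bytes; only such 16-byte copies with the bypassL1 flag use the CG
// (bypass L1) cache modifier, everything else falls back to CA.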
    NVVM::CpAsyncOp::create(
struct NVGPUAsyncCreateGroupLowering
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
    NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
    Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),

struct NVGPUAsyncWaitLowering
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
struct NVGPUMBarrierCreateLowering
  template <typename moduleT>
                                      MemRefType barrierType) const {
    auto global = memref::GlobalOp::create(
        rewriter, funcOp->getLoc(), "__mbarrier",
    symbolTable.insert(global);

  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
        rewriter.getContext(), op.getBarriers().getType());
    memref::GlobalOp global;
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
template <typename SourceOp>
                      nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
    MemRefType mbarrierMemrefType =
        rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
struct NVGPUMBarrierGetLowering
    : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Type resType = op.getMbarrierPointer().getType();
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
          op, barrier, count, adaptor.getPredicate());
          adaptor.getPredicate());
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
          op, tokenType, barrier, count);
          op, tokenType, barrier, count);
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
          op, retType, barrier, adaptor.getToken());
          op, retType, barrier, adaptor.getToken());
struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
          op, barrier, txcount, adaptor.getPredicate());
          op, barrier, txcount, adaptor.getPredicate());
struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
        LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
          op, barrier, phase, ticks);
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
    auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
        adaptor.getDst(), {});
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        adaptor.getPredicate());
struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
        adaptor.getSrc(), {});
        op, adaptor.getTensorMapDescriptor(), dest, coords,
        adaptor.getPredicate());
struct NVGPUGenerateWarpgroupDescriptorLowering
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3

    auto makeConst = [&](uint64_t index) -> Value {
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);
        rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {});
    Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    dsc = insertBit(dsc, strideDim, startStrideBit);
    dsc = insertBit(dsc, leadDim, startLeadBit);
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
                      << "leading_off:" << leadDimVal << "\t"
                      << "stride_off :" << strideDimVal << "\t"
                      << "base_offset:" << offsetVal << "\t"
                      << "layout_type:" << swizzle << " ("
                      << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
                      << ")\n start_addr : " << baseAddr << "\n");
  enum CUtensorMapDataTypeEnum {
    CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
    CU_TENSOR_MAP_DATA_TYPE_UINT16,
    CU_TENSOR_MAP_DATA_TYPE_UINT32,
    CU_TENSOR_MAP_DATA_TYPE_INT32,
    CU_TENSOR_MAP_DATA_TYPE_UINT64,
    CU_TENSOR_MAP_DATA_TYPE_INT64,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
    CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ

    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
    llvm_unreachable("Not supported data type");
struct NVGPUTmaCreateDescriptorOpLowering
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
    Value tensorElementType =
        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    Value boxArrayPtr = LLVM::AllocaOp::create(
        b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
                                      boxArrayPtr, makeI64Const(b, index));
      LLVM::StoreOp::create(b, value, gep);

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    arguments.push_back(promotedOperands[0]);
    arguments.push_back(promotedOperands[1]);
    arguments.push_back(tensorElementType);
    arguments.push_back(makeI64Const(b, (int)desc.getInterleave()));
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle()));
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo()));
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));
    arguments.push_back(boxArrayPtr);

        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
struct NVGPUWarpgroupMmaOpLowering
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    int64_t totalM, totalN, totalK;
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
    int iterationM = 0, iterationN = 0, iterationK = 0;
    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      if (inputElemType.isTF32()) {
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
      } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
      } else if (inputElemType.isInteger(1)) {
        llvm_unreachable("msg: not supported K shape");
      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
          return NVVM::WGMMATypes::bf16;
        if (isa<Float8E4M3FNType>(elemType))
          return NVVM::WGMMATypes::e4m3;
        if (isa<Float8E5M2Type>(elemType))
          return NVVM::WGMMATypes::e5m2;
          return NVVM::WGMMATypes::b1;
          return NVVM::WGMMATypes::s8;
          return NVVM::WGMMATypes::u8;
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");

    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))

    NVVM::MMAShapeAttr generateWgmmaShape() const {

    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
                                          NVVM::WGMMAScaleOut::one);

    NVVM::WGMMAScaleInAttr generateScaleIn() const {
                                         NVVM::WGMMAScaleIn::one);

      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();
      int tileShapeA = matrixTypeA.getDimSize(1);
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
                        << "] [wgmma descriptors] Descriptor A + "
                        << incrementVal << " | \t ");
      return makeAdd(desc, makeI64Const(b, incrementVal));
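// Illustrative note (assumption, not in the original source): for f16 tiles
// (byte = 2) with wgmmaK = 16, advancing k by one moves descriptor A forward
// by 16 * 2 = 32 bytes within the shared-memory tile, while advancing i jumps
// over a whole row of K-tiles (totalK * tileShapeA elements); the resulting
// byte offset is folded into the wgmma matrix descriptor value.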
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
      return makeAdd(desc, makeI64Const(b, incrementVal));
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LLVM_DEBUG(DBGS() << "\t wgmma."
                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                        << "(A[" << (iterationM * wgmmaM) << ":"
                        << (iterationM * wgmmaM) + wgmmaM << "]["
                        << (iterationK * wgmmaK) << ":"
                        << (iterationK * wgmmaK + wgmmaK) << "] * "
                        << " B[" << (iterationK * wgmmaK) << ":"
                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
                        << wgmmaN << "])\n");

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return NVVM::WgmmaMmaAsyncOp::create(
          b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
          itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
    Value generateWgmmaGroup() {
          LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());
      for (int i = 0; i < iterationM; ++i) {
            LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
        wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
                                                  wgmmaResult, matrix, idx);
        : op(op), b(b), adaptor(adaptor) {
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
                        << "] += A[" << totalM << "][" << totalK << "] * B["
                        << totalK << "][" << totalN << "] ---===\n");
          op.getDescriptorA().getType().getTensor().getElementType());
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
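// Illustrative worked example (assumption, not in the original source): for a
// 128x128x64 f16 GEMM, findWgmmaShape picks wgmmaM = 64 (kWgmmaSizeM),
// wgmmaN = 128 and wgmmaK = 16, giving iterationM = 2, iterationN = 1 and
// iterationK = 4, so generateWgmmaGroup below issues 2 x 1 x 4 = 8
// wgmma.mma_async ops in total.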
    Value generateWarpgroupMma() {
      NVVM::WgmmaFenceAlignedOp::create(b);
      Value wgmmaResult = generateWgmmaGroup();
      NVVM::WgmmaGroupSyncAlignedOp::create(b);
      NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());

  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
    WarpgroupGemm warpgroupGemm(op, b, adaptor);
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
struct NVGPUWarpgroupMmaStoreOpLowering
    auto makeConst = [&](int32_t index) -> Value {
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);

      return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
      Value idx = arith::IndexCastOp::create(b, it, x);
      Value idy0 = arith::IndexCastOp::create(b, it, y);
      Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
      Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
      Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
      memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
      memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});

    Value tidx = NVVM::ThreadIdXOp::create(b, i32);
    Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
    Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
    Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
    Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
    constexpr unsigned numAdjacentRegisters = 2;
    constexpr unsigned numStackedMatrices = 2;
    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
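// Illustrative worked example (assumption, not in the original source):
// thread 37 has warpId = 1 and laneId = 5, hence lane4Id = 1 and
// lane4modId = 1, giving a starting row ti = 1 + 1 * 16 = 17 and starting
// column tj = 1 * 2 = 2; each loop iteration then stores two adjacent
// registers and hops by 8 rows or 8 columns, matching the wgmma accumulator
// fragment layout.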
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
    Value matriDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          LLVM::ExtractValueOp::create(b, matriDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
    Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
    Value packStruct = LLVM::PoisonOp::create(b, packStructType);
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = LLVM::InsertValueOp::create(b, structType, structValue,
      innerStructs.push_back(structValue);
      packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
                                               packStruct, matrix, idx);
struct NVGPUTmaFenceOpLowering
  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
    Value tensormapSize =
        op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);

struct NVGPUTmaPrefetchOpLowering
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
        op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
    VectorType inTy = op.getIn().getType();
    auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
      Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
      int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
      for (int i = 0; i < numElems; i++) {
        Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
        Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
        ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
    if (inTy.getRank() == 1) {
      rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
        op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
          OpAdaptor adaptor(operands);
          return convert1DVec(llvm1DVectorTy, adaptor.getIn());
      NVGPUMBarrierCreateLowering,
      NVGPUMBarrierInitLowering,
      NVGPUMBarrierGetLowering,
      NVGPUMBarrierArriveLowering,
      NVGPUMBarrierArriveNoCompleteLowering,
      NVGPUMBarrierTestWaitLowering,
      NVGPUMBarrierTryWaitParityLowering,
      NVGPUTmaAsyncLoadOpLowering,
      NVGPUTmaAsyncStoreOpLowering,
      NVGPUTmaCreateDescriptorOpLowering,
      NVGPUTmaPrefetchOpLowering,
      NVGPUTmaFenceOpLowering,
      NVGPUMBarrierArriveExpectTxLowering,
      NVGPUGenerateWarpgroupDescriptorLowering,
      NVGPUWarpgroupMmaOpLowering,
      NVGPUWarpgroupMmaStoreOpLowering,
      NVGPUWarpgroupMmaInitAccumulatorOpLowering,
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);