#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "nvgpu-to-nvvm"

#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  return LLVM::TruncOp::create(b, b.getI32Type(), value);
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);

  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
  return LLVM::LLVMStructType::getLiteral(

  return vectorResultType;
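// Illustration (a hand-written sketch, not taken from this excerpt): for a
// gpu.mma.sync result lowered to !llvm.array<2 x vector<2xf16>>, the inferred
// intrinsic result type is the literal struct
// !llvm.struct<(vector<2xf16>, vector<2xf16>)>, i.e. one struct member per
// 32-bit register returned by nvvm.mma.sync.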
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);

  auto makeConst = [&](int32_t index) -> Value {
    return LLVM::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(32),
                                    rewriter.getI32IntegerAttr(index));
  };

  if (arrayType.getElementType() == f16x2Ty ||
      arrayType.getElementType() == f32x1Ty) {
    for (unsigned i = 0; i < structType.getBody().size(); i++) {
      Value el =
          LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
      el = rewriter.createOrFold<LLVM::BitcastOp>(
          loc, arrayType.getElementType(), el);
      elements.push_back(el);
    }
  }

  if (arrayType.getElementType() == i32x2Ty ||
      arrayType.getElementType() == f64x2Ty ||
      arrayType.getElementType() == f32x2Ty) {
    for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
      Value vec =
          LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
      Value x1 =
          LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
      Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
                                              i * 2 + 1);
      vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
                                          x1, makeConst(0));
      vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
                                          x2, makeConst(1));
      elements.push_back(vec);
    }
  }

  Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
  result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),
                                       el.index());

  return intrinsicResult;
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = LLVM::ExtractValueOp::create(b, operand, i);

    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));

    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(LLVM::ExtractElementOp::create(

    result.push_back(toUse);

  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));

                              nvgpu::MBarrierGroupType barrierType) {
      nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);

                              nvgpu::MBarrierGroupType barrierType) {
  MemRefLayoutAttrInterface layout;
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,

    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {

                                           vectorResultType.getElementType());

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
                                 adaptor.getSrcMemref(), adaptor.getIndices());

    Value ldMatrixResult = NVVM::LdMatrixOp::create(
        b, ldMatrixResultType, srcPtr,
        op.getTranspose() ? NVVM::MMALayout::col : NVVM::MMALayout::row,
        shape, NVVM::LdStMatrixEltType::B16);

    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = LLVM::PoisonOp::create(b, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
      result = LLVM::InsertValueOp::create(b, result, casted, i);
    }
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,

    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult =
        NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
                            std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
                            std::array<NVVM::MMALayout, 2>{
                                NVVM::MMALayout::row, NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {

  void runOnOperation() override {

    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
          case gpu::AddressSpace::Private:

          llvm_unreachable("unknown address space enum value");
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
        });
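    // Note (an assumption based on the usual NVVM address-space numbering,
    // not spelled out in this excerpt): Generic = 0, Global = 1, Shared = 3,
    // so the mapping above turns gpu.address_space<global> memrefs into
    // addrspace(1) pointers and gpu.address_space<workgroup> memrefs into
    // addrspace(3) pointers.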
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {

    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);

        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
      else
        llvm_unreachable("unsupported type for warpgroup accumulator");

      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

        structBody.push_back(innerStructType);

      auto convertedType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    });
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {

    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {

    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(

    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {

    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
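    // A minimal sketch of how runOnOperation() typically concludes (the exact
    // lines are elided in this excerpt; this assumes the standard MLIR dialect
    // conversion flow):
    RewritePatternSet patterns(&getContext());
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();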
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,

  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)

  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
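// For illustration (hand-assembled from the builder below, not a verbatim
// dump): for shape {16, 8, 32} with s8 A/B operands, an s32 accumulator and
// satfinite overflow, buildMmaSparseAsmString() produces an instruction that
// begins with
//   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
// followed by the "$N"-numbered operand lists and the trailing "0x0"/"0x1"
// sparsity-metadata selector.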
498 static std::string buildMmaSparseAsmString(
499 const std::array<int64_t, 3> &shape,
unsigned matASize,
unsigned matBSize,
500 unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
501 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
502 std::optional<NVVM::MMAIntOverflow> overflow,
unsigned metaDataSelector) {
503 auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
504 return NVVM::stringifyMMATypes(ptxType);
508 llvm::raw_string_ostream ss(asmStr);
509 ss <<
"mma.sp.sync.aligned.m" << shape[0] <<
"n" << shape[1] <<
"k"
510 << shape[2] <<
".row.col.";
513 ss << NVVM::stringifyMMAIntOverflow(*overflow) <<
".";
515 ss << ptxTypeStr(ptxTypeD) <<
"." << ptxTypeStr(ptxTypeA) <<
"."
516 << ptxTypeStr(ptxTypeB) <<
"." << ptxTypeStr(ptxTypeC) <<
" ";
517 unsigned asmArgIdx = 0;
521 for (
const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
523 for (
unsigned i = 0; i < arrSize; i++)
524 ss <<
"$" << asmArgIdx++ << (i < arrSize - 1 ?
"," :
"");
527 ss <<
"$" << asmArgIdx++ <<
",";
528 assert(metaDataSelector <= 1);
529 ss <<
"0x" << metaDataSelector <<
";";
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  asmVals.reserve(matASize + matBSize + matCSize + 1);
  llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return LLVM::InlineAsmOp::create(b,
struct NVGPUMmaSparseSyncLowering

  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,

    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)

    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    Value sparseMetadata = adaptor.getSparseMetadata();
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
struct NVGPUAsyncCopyLowering

  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,

    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
                                   adaptor.getDst(), adaptor.getDstIndices());
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

                                   adaptor.getSrcIndices());
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), static_cast<unsigned>(NVVM::NVVMMemorySpace::Global));
    scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;

    Value srcBytes = adaptor.getSrcElements();

      Value bitwidth = LLVM::ConstantOp::create(

      Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
      srcBytes = LLVM::LShrOp::create(
          b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);

    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    NVVM::CpAsyncOp::create(

struct NVGPUAsyncCreateGroupLowering

  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,

    NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());

    Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),

struct NVGPUAsyncWaitLowering

  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,

    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
struct NVGPUMBarrierCreateLowering

  template <typename moduleT>
                                       MemRefType barrierType) const {

    auto global = memref::GlobalOp::create(
        rewriter, funcOp->getLoc(), "__mbarrier",

    symbolTable.insert(global);

  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,

        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

template <typename SourceOp>

                        nvgpu::MBarrierGroupType mbarType, Value memrefDesc,

    MemRefType mbarrierMemrefType =

        rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
struct NVGPUMBarrierGetLowering
    : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;

  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,

    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Type resType = op.getMbarrierPointer().getType();

struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,

    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
          op, barrier, count, adaptor.getPredicate());
          adaptor.getPredicate());

struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,

    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(

struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,

    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
          op, tokenType, barrier, count);
          op, tokenType, barrier, count);

struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,

    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
          op, retType, barrier, adaptor.getToken());
          op, retType, barrier, adaptor.getToken());

struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,

    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
          op, barrier, txcount, adaptor.getPredicate());
          op, barrier, txcount, adaptor.getPredicate());

struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,

    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value phase =
        LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
          op, barrier, phase, ticks);
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,

    auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
                                   adaptor.getDst(), {});

        static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
    dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);

    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);

        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        NVVM::TMALoadMode::TILE,
        adaptor.getPredicate());

struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
                                   adaptor.getSrc(), {});

        op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
        NVVM::TMAStoreMode::TILE,
        adaptor.getPredicate());
struct NVGPUGenerateWarpgroupDescriptorLowering

  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32

        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3

    auto makeConst = [&](uint64_t index) -> Value {

    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

        rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {});
    Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);

    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    dsc = insertBit(dsc, strideDim, startStrideBit);
    dsc = insertBit(dsc, leadDim, startLeadBit);
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
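    // Worked example (a sketch using the constants above, assuming
    // exclude4LSB == 4): for SWIZZLE_128B the byte layout is 128, so
    // strideDimVal = (128 << 3) >> 4 = 64 and, with sizeN == 64,
    // leadDimVal = (64 * 128) >> 4 = 512. Those values are packed into the
    // 64-bit descriptor at bit offsets 32 (stride), 16 (leading dimension),
    // 62 (swizzle mode), 49 (base offset) and 0 (14-bit base address).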
    LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
           << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
           << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
           << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
           << ")\n start_addr : " << baseAddr;
enum CUtensorMapDataTypeEnum {
  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
  CU_TENSOR_MAP_DATA_TYPE_UINT16,
  CU_TENSOR_MAP_DATA_TYPE_UINT32,
  CU_TENSOR_MAP_DATA_TYPE_INT32,
  CU_TENSOR_MAP_DATA_TYPE_UINT64,
  CU_TENSOR_MAP_DATA_TYPE_INT64,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
};

    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
  llvm_unreachable("Not supported data type");
struct NVGPUTmaCreateDescriptorOpLowering

  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,

    Value tensorElementType =
        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    Value boxArrayPtr = LLVM::AllocaOp::create(
        b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
                                      boxArrayPtr, makeI64Const(b, index));
      LLVM::StoreOp::create(b, value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();

    arguments.push_back(promotedOperands[0]);
    arguments.push_back(promotedOperands[1]);
    arguments.push_back(tensorElementType);
    arguments.push_back(makeI64Const(b, (int)desc.getInterleave()));
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle()));
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo()));
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));
    arguments.push_back(boxArrayPtr);

        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};

        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
struct NVGPUWarpgroupMmaOpLowering

  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;

    int64_t totalM, totalN, totalK;

    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

    int iterationM = 0, iterationN = 0, iterationK = 0;

    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      if (inputElemType.isTF32()) {
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
      } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
      } else if (inputElemType.isInteger(1)) {
        llvm_unreachable("msg: not supported K shape");
      LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
             << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
          return NVVM::WGMMATypes::bf16;
        if (isa<Float8E4M3FNType>(elemType))
          return NVVM::WGMMATypes::e4m3;
        if (isa<Float8E5M2Type>(elemType))
          return NVVM::WGMMATypes::e5m2;
          return NVVM::WGMMATypes::b1;
          return NVVM::WGMMATypes::s8;
          return NVVM::WGMMATypes::u8;
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");

    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))

    NVVM::MMAShapeAttr generateWgmmaShape() const {

    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
                                          NVVM::WGMMAScaleOut::one);

    NVVM::WGMMAScaleInAttr generateScaleIn() const {
                                         NVVM::WGMMAScaleIn::one);

      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();

      int tileShapeA = matrixTypeA.getDimSize(1);
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;

      LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
             << "] [wgmma descriptors] Descriptor A + " << incrementVal

      return makeAdd(desc, makeI64Const(b, incrementVal));

    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();

      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;

      LDBG() << "Descriptor B + " << incrementVal;

      return makeAdd(desc, makeI64Const(b, incrementVal));
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
             << "(A[" << (iterationM * wgmmaM) << ":"
             << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
             << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
             << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
             << "][" << 0 << ":" << wgmmaN << "])";

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return NVVM::WgmmaMmaAsyncOp::create(
          b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
          itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());

      for (int i = 0; i < iterationM; ++i) {
        Value matrixC =
            LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }

        wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
                                                  wgmmaResult, matrix, idx);

        : op(op), b(b), adaptor(adaptor) {

      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
             << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN

          op.getDescriptorA().getType().getTensor().getElementType());

      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;

    Value generateWarpgroupMma() {
      NVVM::WgmmaFenceAlignedOp::create(b);
      Value wgmmaResult = generateWgmmaGroup();
      NVVM::WgmmaGroupSyncAlignedOp::create(b);
      NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());

  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,

    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
struct NVGPUWarpgroupMmaStoreOpLowering

    auto makeConst = [&](int32_t index) -> Value {

    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,

      Value idx = arith::IndexCastOp::create(b, it, x);
      Value idy0 = arith::IndexCastOp::create(b, it, y);
      Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
      Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
      Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
      memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
      memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = NVVM::ThreadIdXOp::create(b, i32);
    Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
    Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
    Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
    Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    constexpr unsigned numAdjacentRegisters = 2;

    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }

  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,

    Value matriDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());

      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          LLVM::ExtractValueOp::create(b, matriDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
struct NVGPUWarpgroupMmaInitAccumulatorOpLowering

  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,

    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())

    Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
    Value packStruct = LLVM::PoisonOp::create(b, packStructType);

      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = LLVM::InsertValueOp::create(b, structType, structValue,
      innerStructs.push_back(structValue);

      packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
                                               packStruct, matrix, idx);

struct NVGPUTmaFenceOpLowering

  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,

    Value tensormapSize =

        op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);

struct NVGPUTmaPrefetchOpLowering

  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,

        op, nullptr, nullptr,
        adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,

    VectorType inTy = op.getIn().getType();

    auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
      Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
      int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
      for (int i = 0; i < numElems; i++) {
        Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
        Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
        ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);
      }
      return ret1DVec;
    };
    if (inTy.getRank() == 1) {
      rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));

        op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
          OpAdaptor adaptor(operands);
          return convert1DVec(llvm1DVectorTy, adaptor.getIn());
  patterns.add<
      NVGPUMBarrierCreateLowering,
      NVGPUMBarrierInitLowering,
      NVGPUMBarrierGetLowering,
      NVGPUMBarrierArriveLowering,
      NVGPUMBarrierArriveNoCompleteLowering,
      NVGPUMBarrierTestWaitLowering,
      NVGPUMBarrierTryWaitParityLowering,
      NVGPUTmaAsyncLoadOpLowering,
      NVGPUTmaAsyncStoreOpLowering,
      NVGPUTmaCreateDescriptorOpLowering,
      NVGPUTmaPrefetchOpLowering,
      NVGPUTmaFenceOpLowering,
      NVGPUMBarrierArriveExpectTxLowering,
      NVGPUGenerateWarpgroupDescriptorLowering,
      NVGPUWarpgroupMmaOpLowering,
      NVGPUWarpgroupMmaStoreOpLowering,
      NVGPUWarpgroupMmaInitAccumulatorOpLowering,
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
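// Usage note (an illustration, not taken from this file): once these patterns
// are registered, the lowering is normally exercised either programmatically
// through populateNVGPUToNVVMConversionPatterns() followed by
// applyPartialConversion() against an LLVM/NVVM-legal ConversionTarget, or
// from the command line with the registered pass, e.g.
//   mlir-opt input.mlir --convert-nvgpu-to-nvvm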