#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "nvgpu-to-nvvm"

#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
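/// GPU has 32 bit registers, this function truncates values when larger width
/// is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {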
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  return LLVM::TruncOp::create(b, b.getI32Type(), value);
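/// Returns the type for the intrinsic given the vectorResultType of the
/// gpu.mma.sync operation.
static Type inferIntrinsicResultType(Type vectorResultType) {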
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
  return LLVM::LLVMStructType::getLiteral(
  return vectorResultType;
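/// Convert the SSA result of the NVVM intrinsic nvvm.mma.sync (which is
/// always an LLVM struct) into the fragmented array-of-vectors form matching
/// `resultType`.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {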
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);

  auto makeConst = [&](int32_t index) -> Value {

  if (arrayType.getElementType() == f16x2Ty ||
      arrayType.getElementType() == f32x1Ty) {
    for (unsigned i = 0; i < structType.getBody().size(); i++) {
          LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i);
          loc, arrayType.getElementType(), el);
      elements.push_back(el);

  if (arrayType.getElementType() == i32x2Ty ||
      arrayType.getElementType() == f64x2Ty ||
      arrayType.getElementType() == f32x2Ty) {
    for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
          LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType());
          LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2);
      Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult,
      vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
      vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec,
      elements.push_back(vec);

  Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(),

  return intrinsicResult;
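/// The gpu.mma.sync converter below expects matrix fragment operands to be
/// given as 2D vectors; this helper unpacks them into the flat list of scalar
/// registers (with the appropriate casts) expected by nvvm.mma.sync.
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,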
                                              NVVM::MMATypes operandPtxType) {
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = LLVM::ExtractValueOp::create(b, operand, i);
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse));

    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(LLVM::ExtractElementOp::create(
    result.push_back(toUse);
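/// Returns whether the mbarrier object has shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {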
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));

                                      nvgpu::MBarrierGroupType barrierType) {
      nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);

                                      nvgpu::MBarrierGroupType barrierType) {
  MemRefLayoutAttrInterface layout;
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
        vectorResultType.getElementType());

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
        adaptor.getSrcMemref(), adaptor.getIndices());

    Value ldMatrixResult = NVVM::LdMatrixOp::create(
        b, ldMatrixResultType, srcPtr,
        op.getTranspose() ? NVVM::MMALayout::col : NVVM::MMALayout::row,
        shape, NVVM::LdStMatrixEltType::B16);

    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = LLVM::PoisonOp::create(b, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
          num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i)
      Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register);
      result = LLVM::InsertValueOp::create(b, result, casted, i);
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
    return NVVM::MMATypes::s8;
    return NVVM::MMATypes::s4;
    return NVVM::MMATypes::f16;
    return NVVM::MMATypes::f64;
    return NVVM::MMATypes::tf32;
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult =
        NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC,
            std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
            std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                           NVVM::MMALayout::col});
        desiredRetTy, intrinsicResult,
struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {

  void runOnOperation() override {
        converter, [](gpu::AddressSpace space) -> unsigned {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
          case gpu::AddressSpace::Private:
          llvm_unreachable("unknown address space enum value");
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
        llvm_unreachable("unsupported type for warpgroup accumulator");
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
        structBody.push_back(innerStructType);
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {

    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)

static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);

  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  asmVals.reserve(matASize + matBSize + matCSize + 1);
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return LLVM::InlineAsmOp::create(b,
struct NVGPUMmaSparseSyncLowering
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)

    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
        typeConverter->convertType(op->getResultTypes()[0]));

    Value sparseMetadata = adaptor.getSparseMetadata();
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
        LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
    if (failed(intrinsicResult))

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
        (*intrinsicResult)->getResult(0), rewriter));
struct NVGPUAsyncCopyLowering
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
        adaptor.getDst(), adaptor.getDstIndices());
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
          loc, "source memref address space not convertible to integer");

        adaptor.getSrcIndices());
      scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;

    Value srcBytes = adaptor.getSrcElements();
      Value bitwidth = LLVM::ConstantOp::create(
      Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes);
      srcBytes = LLVM::LShrOp::create(
          b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32);
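    // cp.async only supports the L1-bypassing cache-global (.cg) modifier for
    // 16-byte copies; every other size falls back to cache-all (.ca).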
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    NVVM::CpAsyncOp::create(

struct NVGPUAsyncCreateGroupLowering
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
    NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc());
    Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),

struct NVGPUAsyncWaitLowering
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups);
struct NVGPUMBarrierCreateLowering
  template <typename moduleT>
                                       MemRefType barrierType) const {
    auto global = memref::GlobalOp::create(
        rewriter, funcOp->getLoc(), "__mbarrier",
    symbolTable.insert(global);

  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
        rewriter.getContext(), op.getBarriers().getType());
    memref::GlobalOp global;
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

template <typename SourceOp>
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
    MemRefType mbarrierMemrefType =
        rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
struct NVGPUMBarrierGetLowering
    : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
  using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Type resType = op.getMbarrierPointer().getType();

struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
          op, barrier, count, adaptor.getPredicate());
          adaptor.getPredicate());

struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
          op, tokenType, barrier, count);
          op, tokenType, barrier, count);

struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
          op, retType, barrier, adaptor.getToken());
          op, retType, barrier, adaptor.getToken());

struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
          op, barrier, txcount, adaptor.getPredicate());
          op, barrier, txcount, adaptor.getPredicate());

struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
        LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
        op, barrier, phase, ticks);
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
    auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
        adaptor.getDst(), {});
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        adaptor.getPredicate());

struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
        adaptor.getSrc(), {});
        op, adaptor.getTensorMapDescriptor(), dest, coords, Value{},
        NVVM::TMAStoreMode::TILE,
        adaptor.getPredicate());
struct NVGPUGenerateWarpgroupDescriptorLowering
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3

    auto makeConst = [&](uint64_t index) -> Value {
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return LLVM::ShlOp::create(b, ti64, value, makeConst(shift));
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return LLVM::LShrOp::create(b, ti64, value, makeConst(shift));
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit));

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
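    // exclude4LSB: number of bits that need to be excluded when building the
    // matrix descriptor for wgmma operations.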
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);
        rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {});
    Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr);
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
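    // The 64-bit warpgroup matrix descriptor is assembled field by field:
    // swizzle mode at bit 62, base offset at bit 49, stride dimension at
    // bit 32, leading dimension at bit 16, and the 14-bit base address at
    // bit 0. (A summary of the insertBit calls below; see the PTX ISA for the
    // authoritative encoding.)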
    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    dsc = insertBit(dsc, strideDim, startStrideBit);
    dsc = insertBit(dsc, leadDim, startLeadBit);
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LDBG() << "Generating warpgroup.descriptor: "
           << "leading_off:" << leadDimVal << "\t"
           << "stride_off :" << strideDimVal << "\t"
           << "base_offset:" << offsetVal << "\t"
           << "layout_type:" << swizzle << " ("
           << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
           << ")\n start_addr : " << baseAddr;
enum CUtensorMapDataTypeEnum {
  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
  CU_TENSOR_MAP_DATA_TYPE_UINT16,
  CU_TENSOR_MAP_DATA_TYPE_UINT32,
  CU_TENSOR_MAP_DATA_TYPE_INT32,
  CU_TENSOR_MAP_DATA_TYPE_UINT64,
  CU_TENSOR_MAP_DATA_TYPE_INT64,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ

    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
  llvm_unreachable("Not supported data type");
struct NVGPUTmaCreateDescriptorOpLowering
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
    Value tensorElementType =
        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    Value boxArrayPtr = LLVM::AllocaOp::create(
        b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType,
                                      boxArrayPtr, makeI64Const(b, index));
      LLVM::StoreOp::create(b, value, gep);

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    arguments.push_back(promotedOperands[0]);
    arguments.push_back(promotedOperands[1]);
    arguments.push_back(tensorElementType);
    arguments.push_back(makeI64Const(b, (int)desc.getInterleave()));
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle()));
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo()));
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));
    arguments.push_back(boxArrayPtr);

        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

struct NVGPUWarpgroupMmaOpLowering
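  /// Helper that drives the lowering of a single nvgpu.warpgroup.mma: it picks
  /// the wgmma tile shape for the input element type, steps through the A/B
  /// descriptors, and emits the fence / wgmma.mma_async chain / group-sync /
  /// wait sequence (a summary of the members defined below).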
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    int64_t totalM, totalN, totalK;
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
    int iterationM = 0, iterationN = 0, iterationK = 0;

    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      if (inputElemType.isTF32()) {
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
      } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
      } else if (inputElemType.isInteger(1)) {
        llvm_unreachable("msg: not supported K shape");
      LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
             << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
          return NVVM::WGMMATypes::bf16;
        if (isa<Float8E4M3FNType>(elemType))
          return NVVM::WGMMATypes::e4m3;
        if (isa<Float8E5M2Type>(elemType))
          return NVVM::WGMMATypes::e5m2;
          return NVVM::WGMMATypes::b1;
          return NVVM::WGMMATypes::s8;
          return NVVM::WGMMATypes::u8;
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");

    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))

    NVVM::MMAShapeAttr generateWgmmaShape() const {

    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
                                          NVVM::WGMMAScaleOut::one);

    NVVM::WGMMAScaleInAttr generateScaleIn() const {
                                         NVVM::WGMMAScaleIn::one);

      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();
      int tileShapeA = matrixTypeA.getDimSize(1);
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
             << "] [wgmma descriptors] Descriptor A + " << incrementVal
      return makeAdd(desc, makeI64Const(b, incrementVal));

    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      LDBG() << "Descriptor B + " << incrementVal;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LDBG() << "\t wgmma."
             << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
             << "(A[" << (iterationM * wgmmaM) << ":"
             << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
             << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
             << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
             << "][" << 0 << ":" << wgmmaN << "])";

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return NVVM::WgmmaMmaAsyncOp::create(
          b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape,
          itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
    Value generateWgmmaGroup() {
          LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType());

      for (int i = 0; i < iterationM; ++i) {
            LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
        wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(),
                                                  wgmmaResult, matrix, idx);

        : op(op), b(b), adaptor(adaptor) {
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
             << totalM << "][" << totalK << "] * B[" << totalK << "]["
             << totalN
          op.getDescriptorA().getType().getTensor().getElementType());
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;

    Value generateWarpgroupMma() {
      NVVM::WgmmaFenceAlignedOp::create(b);
      Value wgmmaResult = generateWgmmaGroup();
      NVVM::WgmmaGroupSyncAlignedOp::create(b);
      NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup());

  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
    WarpgroupGemm warpgroupGemm(op, b, adaptor);
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
struct NVGPUWarpgroupMmaStoreOpLowering
    auto makeConst = [&](int32_t index) -> Value {
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);

      return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs);
      return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs);

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
      Value idx = arith::IndexCastOp::create(b, it, x);
      Value idy0 = arith::IndexCastOp::create(b, it, y);
      Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1));
      Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i);
      Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1);
      memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0});
      memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1});

    Value tidx = NVVM::ThreadIdXOp::create(b, i32);
    Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize);
    Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize);
    Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4);
    Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
      ti = makeAdd(ti, makeConst(offset));
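    // Reading of the index math above: each lane owns pairs of adjacent
    // accumulator elements, with row = laneId / 4 + 16 * warpId (plus 8 for
    // the second stacked matrix) and column = 2 * (laneId % 4) (plus 8 per
    // register pair).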
    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    constexpr unsigned numAdjacentRegisters = 2;
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);

  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
    Value matriDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          LLVM::ExtractValueOp::create(b, matriDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
    Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType));
    Value packStruct = LLVM::PoisonOp::create(b, packStructType);
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = LLVM::InsertValueOp::create(b, structType, structValue,
      innerStructs.push_back(structValue);
      packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(),
                                               packStruct, matrix, idx);

struct NVGPUTmaFenceOpLowering
  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
    Value tensormapSize =
        op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);

struct NVGPUTmaPrefetchOpLowering
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
        op, nullptr, nullptr,
        adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
    VectorType inTy = op.getIn().getType();
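    // rcp.approx.ftz.f32 operates on scalars, so 1-D vectors are unrolled
    // element by element below; higher-rank vectors go through
    // handleMultidimensionalVectors, which reduces them to the 1-D case.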
    auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
      Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy);
      int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
      for (int i = 0; i < numElems; i++) {
        Value elem = LLVM::ExtractElementOp::create(b, inVec, idx);
        Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem);
        ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx);

    if (inTy.getRank() == 1) {
      rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));

        op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
          OpAdaptor adaptor(operands);
          return convert1DVec(llvm1DVectorTy, adaptor.getIn());
  patterns.add<
      NVGPUMBarrierCreateLowering,
      NVGPUMBarrierInitLowering,
      NVGPUMBarrierGetLowering,
      NVGPUMBarrierArriveLowering,
      NVGPUMBarrierArriveNoCompleteLowering,
      NVGPUMBarrierTestWaitLowering,
      NVGPUMBarrierTryWaitParityLowering,
      NVGPUTmaAsyncLoadOpLowering,
      NVGPUTmaAsyncStoreOpLowering,
      NVGPUTmaCreateDescriptorOpLowering,
      NVGPUTmaPrefetchOpLowering,
      NVGPUTmaFenceOpLowering,
      NVGPUMBarrierArriveExpectTxLowering,
      NVGPUGenerateWarpgroupDescriptorLowering,
      NVGPUWarpgroupMmaOpLowering,
      NVGPUWarpgroupMmaStoreOpLowering,
      NVGPUWarpgroupMmaInitAccumulatorOpLowering,
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);