#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define DBGSE() (llvm::dbgs())

#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
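
// Helpers for the gpu.mma.sync lowering. truncToI32 narrows a value to the
// GPU's 32-bit registers when a larger width is not needed, and
// inferIntrinsicResultType maps the vector result type of nvgpu.mma.sync onto
// the LLVM struct type produced by the nvvm.mma.sync intrinsic.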
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");

static Type inferIntrinsicResultType(Type vectorResultType) {
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements() * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements() * 2, f32Ty));
  }
  if (a.getElementType() == f32x1Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f32Ty));
  }
  return vectorResultType;
}
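
// Convert the SSA result of the NVVM intrinsic nvvm.mma.sync (which is always
// an LLVM struct) into a fragment that matches the 2-D vector type expected by
// the nvgpu.mma.sync op, bitcasting or repacking the 32-bit registers as
// needed.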
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // 32-bit-wide elements (f16x2, f32x1) can be bitcast directly into the
    // element type of the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // i32, f64 and f32 values are returned as individual scalars and must be
    // packed back into the two-element rows of the result vector.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {
      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
                                                         i * 2 + 1);
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
                                                    el.index());
    }
    return result;
  }

  return intrinsicResult;
}
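
// The gpu.mma.sync converter below expects matrix fragment operands to be
// given as 2-D vectors. unpackOperandVector flattens such an operand into the
// list of 32-bit registers (packed vectors, or i32/f32/f64 scalars) that the
// NVVM intrinsic expects.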
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);

    // 4xi8, 8xi4 and tf32 operands are handed to the intrinsic as plain i32.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
      continue;
    }

    // i32, f64 and f32 inner vectors are unpacked into individual scalars.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(b.create<LLVM::ExtractElementOp>(
            toUse,
            b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

/// Returns whether the mbarrier object has a shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = {};
  if (isMbarrierShared(barrierType))
    memorySpace =
        IntegerAttr::get(IntegerType::get(context, 64),
                         nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
  return memorySpace;
}

/// Returns the memref type that can be used to represent an mbarrier object.
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
  MemRefLayoutAttrInterface layout;
  return MemRefType::get({barrierType.getNumBarriers()},
                         IntegerType::get(context, 64), layout, memorySpace);
}
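
// Lowers nvgpu.ldmatrix to nvvm.ldmatrix. The op collectively loads 8x8 matrix
// tiles from shared memory; each thread receives its fragment as one or more
// 32-bit registers, which are bitcast back to the 2-D vector result type. An
// illustrative (abbreviated) example of the input IR:
//
//   %frag = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false}
//             : memref<128x128xf16, 3> -> vector<4x2xf16>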
struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
    Value srcPtr =
        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
        ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // Unpack the 32-bit registers, bitcast each back to its vector form, and
    // repack them into the converted result type.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = b.create<LLVM::UndefOp>(finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
      result = b.create<LLVM::InsertValueOp>(result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Convert the given type into the corresponding PTX element type.
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}
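
// Lowers nvgpu.mma.sync to nvvm.mma.sync, e.g. (illustrative):
//
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//          : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>)
//            -> vector<2x2xf16>
//
// Operand fragments are unpacked into 32-bit registers with
// unpackOperandVector, the PTX element types are deduced with getNvvmMmaType,
// and the struct returned by the intrinsic is repacked by
// convertIntrinsicResult.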
struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // The operand shapes decide which intrinsic this op is lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = b.create<NVVM::MmaOp>(
        intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/std::nullopt,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};
struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect,
                    NVVM::NVVMDialect, arith::ArithDialect>();
  }

  void runOnOperation() override {
    LowerToLLVMOptions options(&getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext(), options);
    IRRewriter rewriter(&getContext());
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
        });
    // Device-side async tokens cannot be materialized in NVVM; convert them to
    // a dummy i32 so they can be dropped during conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);

      unsigned numMembers;
      if (elemType.isF32() || elemType.isInteger(32))
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
      else
        llvm_unreachable("unsupported type for warpgroup accumulator");

      SmallVector<Type> innerStructBody;
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

      SmallVector<Type> structBody;
      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
        structBody.push_back(innerStructType);

      auto convertedType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    });
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 64));
    });
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
          return converter.convertType(IntegerType::get(type.getContext(), 64));
        });
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
          nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
    });
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
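
// The sparse variant (mma.sp.sync) has no NVVM intrinsic, so the lowering
// below assembles the PTX instruction and its register constraints by hand and
// emits them as an llvm.inline_asm op. The two helpers that follow build the
// constraint string and the instruction string for that op.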
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
    ss << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    ss << "r,";
  // The final operand is the sparsity metadata register; the sparsity
  // selector appears as a direct literal.
  ss << "r";
  return str;
}

static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  // The operand string is structured into sections
  // `{matC...}, {matA...}, {matB...}, {matC...}`.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
  return asmStr;
}
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  SmallVector<Value> asmVals;
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return b.create<LLVM::InlineAsmOp>(
      /*resultTypes=*/intrinsicResultType,
      /*operands=*/asmVals,
      /*asm_string=*/asmStr,
      /*constraints=*/constraintStr,
      /*has_side_effects=*/true,
      /*is_align_stack=*/false,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
}
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // As with mma.sync, F32 works only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() !=
        LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};
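
// Lowers nvgpu.device_async_copy to NVVM cp.async: compute source and
// destination pointers from the memref descriptors, cast the source pointer to
// the global address space, and pick the cache modifier based on the copy size
// and the bypassL1 attribute.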
struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Location loc = op.getLoc();
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dstPtr =
        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
                             adaptor.getDstIndices(), rewriter);
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

    Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    // The intrinsic takes a global pointer, so an address-space cast is
    // required.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;

    Value srcBytes = adaptor.getSrcElements();
    if (srcBytes) {
      // Convert the optional source element count to bytes:
      // srcBytes = (srcElements * elementBitWidth) >> 3.
      Value c3I32 =
          b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
      Value bitwidth = b.create<LLVM::ConstantOp>(
          b.getI32Type(),
          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
      Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
      srcBytes = b.create<LLVM::LShrOp>(
          b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
    }
    // The L1-bypassing (cg) variant is only valid for 16-byte copies;
    // otherwise fall back to the default (ca) cache modifier.
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    b.create<NVVM::CpAsyncOp>(
        dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
        srcBytes);
struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is absent, wait for all outstanding groups.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};
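
// The nvgpu.mbarrier.* ops operate on an mbarrier group that lives in memory.
// The patterns below materialize that group as a memref.global ("__mbarrier")
// in the surrounding module and lower each op to the corresponding NVVM
// mbarrier intrinsic, choosing the shared-memory or generic variant depending
// on the group's address space.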
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = rewriter.create<memref::GlobalOp>(
        funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/ElementsAttr(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
/// Base class for lowering mbarrier operations to NVVM intrinsics.
template <typename SourceOp>
struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  /// Returns the base pointer of the mbarrier object at the given index.
  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
                       Value mbarId,
                       ConversionPatternRewriter &rewriter) const {
    MemRefType mbarrierMemrefType =
        nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
    return ConvertToLLVMPattern::getStridedElementPtr(
        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
  }
};
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(mbarrierType)) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
          op, barrier, count, adaptor.getPredicate());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(
          op, barrier, count, adaptor.getPredicate());
    }
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
          op, tokenType, barrier, count);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
          op, tokenType, barrier, count);
    }
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type retType = rewriter.getI1Type();
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
          op, retType, barrier, adaptor.getToken());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
          op, retType, barrier, adaptor.getToken());
    }
struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value txcount = truncToI32(b, adaptor.getTxcount());
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
          op, barrier, txcount, adaptor.getPredicate());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
          op, barrier, txcount, adaptor.getPredicate());
    }
struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase =
        b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
        op, barrier, phase, ticks);
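
// nvgpu.tma.async.load and nvgpu.tma.async.store lower to the NVVM bulk
// tensor-copy ops: the shared-memory pointer and (for loads) the mbarrier
// pointer are computed from their memref descriptors, while the tensor-map
// descriptor and coordinates are forwarded.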
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dest = getStridedElementPtr(op->getLoc(), dstMemrefType,
                                      adaptor.getDst(), {}, rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);

    SmallVector<Value> coords = adaptor.getCoordinates();
    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        adaptor.getPredicate());
    return success();
  }
};

struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    Value src = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                     adaptor.getSrc(), {}, rewriter);

    SmallVector<Value> coords = adaptor.getCoordinates();
    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
        op, adaptor.getTensorMapDescriptor(), src, coords,
        adaptor.getPredicate());
    return success();
  }
};
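
// Lowers nvgpu.warpgroup.generate.descriptor to the 64-bit shared-memory
// matrix descriptor consumed by wgmma.mma_async. Judging from the insertBit
// calls below, the fields are packed as (low to high bits):
//   [ 0..13] start address (14 bits of the shared-memory pointer)
//   [16..29] leading-dimension byte offset
//   [32..45] stride-dimension byte offset
//   [49..51] base offset
//   [62..63] swizzle mode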
struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Type ti64 = b.getIntegerType(64);

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

    unsigned layout =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
                                                                    : 1;
    unsigned swizzle =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
                                                                    : 0;

    auto makeConst = [&](uint64_t index) -> Value {
      return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
    };
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

    Value baseAddr = getStridedElementPtr(
        op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {}, rewriter);
    Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
    // Extract the 14 address bits the descriptor carries: (ptr << 46) >> 50.
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    // Swizzle kind.
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    // Base offset.
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    // Stride-dimension byte offset.
    dsc = insertBit(dsc, strideDim, startStrideBit);
    // Leading-dimension byte offset.
    dsc = insertBit(dsc, leadDim, startLeadBit);
    // Start address.
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
                      << "leading_off:" << leadDimVal << "\t"
                      << "stride_off :" << strideDimVal << "\t"
                      << "base_offset:" << offsetVal << "\t"
                      << "layout_type:" << swizzle << " ("
                      << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
                      << ")\n start_addr : " << baseAddr << "\n");
enum CUtensorMapDataTypeEnum {
  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
  CU_TENSOR_MAP_DATA_TYPE_UINT16,
  CU_TENSOR_MAP_DATA_TYPE_UINT32,
  CU_TENSOR_MAP_DATA_TYPE_INT32,
  CU_TENSOR_MAP_DATA_TYPE_UINT64,
  CU_TENSOR_MAP_DATA_TYPE_INT64,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
};

static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
  if (type.isUnsignedInteger(8))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
  if (type.isUnsignedInteger(16))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
  if (type.isUnsignedInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
  if (type.isUnsignedInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
  if (type.isSignlessInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
  if (type.isSignlessInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
  if (type.isF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
  if (type.isF32())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
  if (type.isF64())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
  if (type.isBF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);

  llvm_unreachable("Not supported data type");
}
struct NVGPUTmaCreateDescriptorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
    Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

    Value tensorElementType =
        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    // Stack-allocate an array of five i64 box dimensions and fill it.
    Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
                                                 makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
                                        boxArrayPtr, makeI64Const(b, index));
      b.create<LLVM::StoreOp>(value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    // Arguments of the runtime call that encodes the tensor map.
    SmallVector<Value> arguments;
    arguments.push_back(promotedOperands[0]);
    arguments.push_back(promotedOperands[1]);
    arguments.push_back(tensorElementType);
    arguments.push_back(
        makeI64Const(b, (int)desc.getInterleave()));
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle()));
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo()));
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));
    arguments.push_back(boxArrayPtr);

    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
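
// Lowers nvgpu.warpgroup.mma. The WarpgroupGemm helper below splits the GEMM
// described by the op into wgmma.mma_async operations of the shape chosen by
// findWgmmaShape, advances the A/B shared-memory descriptors for each
// iteration, and threads the accumulator fragments through the generated ops.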
struct NVGPUWarpgroupMmaOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

  /// Helper that splits a warpgroup GEMM into wgmma.mma_async operations.
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    ImplicitLocOpBuilder &b;
    OpAdaptor adaptor;

    // Entire shape of the GEMM.
    int64_t totalM, totalN, totalK;

    // Shape of a single wgmma instruction.
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

    // Iteration counts needed to cover the GEMM with wgmma instructions.
    int iterationM = 0, iterationN = 0, iterationK = 0;

    /// Pick the wgmma instruction shape; only K depends on the element type.
    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      wgmmaM = 64;
      wgmmaN = sizeN;
      if (inputElemType.isTF32()) {
        wgmmaK = 8;
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
        wgmmaK = 16;
      } else if (inputElemType.isInteger(1)) {
        wgmmaK = 256;
      } else {
        llvm_unreachable("msg: not supported K shape");
      }
      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
    }
    /// Map an element type to the corresponding NVVM WGMMA type attribute.
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
        if (elemType.isF32() || elemType.isTF32())
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
        if (elemType.isBF16())
          return NVVM::WGMMATypes::bf16;
        if (elemType.isFloat8E4M3FN())
          return NVVM::WGMMATypes::e4m3;
        if (elemType.isFloat8E5M2())
          return NVVM::WGMMATypes::e5m2;
        if (elemType.isInteger(1))
          return NVVM::WGMMATypes::b1;
        if (elemType.isInteger(8))
          return NVVM::WGMMATypes::s8;
        if (elemType.isUnsignedInteger(8))
          return NVVM::WGMMATypes::u8;
        if (elemType.isInteger(32))
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");
      };
      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
    }

    /// Generate the row/col layout attribute for an operand.
    NVVM::MMALayoutAttr
    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))
        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
    }

    NVVM::MMAShapeAttr generateWgmmaShape() const {
      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
    }

    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
                                          NVVM::WGMMAScaleOut::one);
    }

    NVVM::WGMMAScaleInAttr generateScaleIn() const {
      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
                                         NVVM::WGMMAScaleIn::one);
    }
    /// Compute the A descriptor for iteration (i, j, k) by adding the byte
    /// offset of that tile to the base descriptor.
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();
      int byte = elemA.getIntOrFloatBitWidth() / 8;
      int tileShapeA = matrixTypeA.getDimSize(1);
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
                        << "] [wgmma descriptors] Descriptor A + "
                        << incrementVal << " | \t ");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// Compute the B descriptor for iteration (i, j, k).
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }
    /// Generate a single wgmma.mma_async operation for iteration (i, j, k),
    /// threading the accumulator fragment matrixC through it.
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LLVM_DEBUG(DBGS() << "\t wgmma."
                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                        << "(A[" << (iterationM * wgmmaM) << ":"
                        << (iterationM * wgmmaM) + wgmmaM << "]["
                        << (iterationK * wgmmaK) << ":"
                        << (iterationK * wgmmaK + wgmmaK) << "] * "
                        << " B[" << (iterationK * wgmmaK) << ":"
                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
                        << wgmmaN << "])\n");

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, /*useF32=*/true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return b.create<NVVM::WgmmaMmaAsyncOp>(
          matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
          itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }
    /// Generate all wgmma instructions needed to cover the GEMM shape.
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());

      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
                                                    wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }
  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM shape.
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
                        << "] += A[" << totalM << "][" << totalK << "] * B["
                        << totalK << "][" << totalN << "] ---===\n");

      // Find the shape of a single wgmma instruction.
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iteration counts needed to cover the GEMM with that shape.
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
    }
    /// Generate the warpgroup MMA: fence, the wgmma group, then sync and wait.
    Value generateWarpgroupMma() {
      b.create<NVVM::WgmmaFenceAlignedOp>();
      Value wgmmaResult = generateWgmmaGroup();
      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
      return wgmmaResult;
    }
  };

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    WarpgroupGemm warpgroupGemm(op, b, adaptor);
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};
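
// Lowers nvgpu.warpgroup.mma.store. Each warp owns a 16-row slice of the
// accumulator; the helpers below derive, from the lane and warp ids, the
// (row, column) position of each pair of adjacent registers and store the pair
// back to the destination memref two elements at a time.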
struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// Store one fragment (inner LLVM struct) of the accumulator into the
  /// destination memref, starting at the given row offset.
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(kWarpSize);

    // Extract registers (i, i + 1) and store them at (x, y) and (x, y + 1).
    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = b.create<arith::IndexCastOp>(it, x);
      Value idy0 = b.create<arith::IndexCastOp>(it, y);
      Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
      Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
      Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
      b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
      b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    // Each thread holds 2 adjacent registers per 8x8 tile...
    constexpr unsigned numAdjacentRegisters = 2;
    // ...from two vertically stacked 8x8 tiles.
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matrixDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
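
// Lowers nvgpu.warpgroup.mma.init.accumulator by filling every register of the
// converted accumulator struct with a zero constant of its element type.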
struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
                        .getBody()
                        .front();
    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
    Value packStruct = b.create<LLVM::UndefOp>(packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the inner structs and set every member to zero.
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = b.create<LLVM::InsertValueOp>(
            structValue, zero, ArrayRef<int64_t>({i}));
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs back into a single struct.
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
                                                 packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};
struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
        op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
    return success();
  }
};

struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
  LogicalResult
  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i64Ty = b.getI64Type();
    auto f32Ty = b.getF32Type();
    VectorType inTy = op.getIn().getType();
    // Apply rcp.approx.ftz.f32 to each element of the vector.
    auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
      Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
      int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
      for (int i = 0; i < numElems; i++) {
        Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
        Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
        Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
        ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
      }
      return ret1DVec;
    };
    if (inTy.getRank() == 1) {
      rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
      return success();
    }
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
          OpAdaptor adaptor(operands);
          return convert1DVec(llvm1DVectorTy, adaptor.getIn());
        },
        rewriter);
  }
};
void mlir::populateNVGPUToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<
      NVGPUMBarrierCreateLowering,
      NVGPUMBarrierInitLowering,
      NVGPUMBarrierArriveLowering,
      NVGPUMBarrierArriveNoCompleteLowering,
      NVGPUMBarrierTestWaitLowering,
      NVGPUMBarrierTryWaitParityLowering,
      NVGPUTmaAsyncLoadOpLowering,
      NVGPUTmaAsyncStoreOpLowering,
      NVGPUTmaCreateDescriptorOpLowering,
      NVGPUTmaPrefetchOpLowering,
      NVGPUMBarrierArriveExpectTxLowering,
      NVGPUGenerateWarpgroupDescriptorLowering,
      NVGPUWarpgroupMmaOpLowering,
      NVGPUWarpgroupMmaStoreOpLowering,
      NVGPUWarpgroupMmaInitAccumulatorOpLowering,
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
}