29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
34 #define DEBUG_TYPE "nvgpu-to-nvvm"
35 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36 #define DBGSE() (llvm::dbgs())
39 #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
40 #include "mlir/Conversion/Passes.h.inc"
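// truncToI32: GPU has 32-bit registers, so this helper truncates values when a larger width is not needed.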
53 assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
63 auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
71 if (a.getElementType() == f16x2Ty) {
75 if (a.getElementType() == i32x2Ty) {
80 if (a.getElementType() == f64x2Ty) {
83 if (a.getElementType() == f32x2Ty) {
92 return vectorResultType;
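// convertIntrinsicResult: converts the SSA result of nvvm.mma.sync (always an LLVM struct) back into the fragmented array/vector form expected at the nvgpu level.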
104 auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
105 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
115 auto makeConst = [&](int32_t index) -> Value {
125 if (arrayType.getElementType() == f16x2Ty ||
126 arrayType.getElementType() == f32x1Ty) {
127 for (unsigned i = 0; i < structType.getBody().size(); i++) {
129 rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
131 loc, arrayType.getElementType(), el);
132 elements.push_back(el);
140 if (arrayType.getElementType() == i32x2Ty ||
141 arrayType.getElementType() == f64x2Ty ||
142 arrayType.getElementType() == f32x2Ty) {
144 for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
146 rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
148 rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
149 Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
151 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
153 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
155 elements.push_back(vec);
160 Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
162 result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
168 return intrinsicResult;
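// unpackOperandVector: the gpu.mma.sync lowering expects matrix fragment operands as 2-D vectors; this unpacks them into the flat list of registers consumed by the NVVM intrinsic.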
178 NVVM::MMATypes operandPtxType) {
187 auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
189 for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
190 Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
194 if (arrayTy.getElementType() == i8x4Ty ||
195 arrayTy.getElementType() == i4x8Ty ||
196 (arrayTy.getElementType() == f32x1Ty &&
197 operandPtxType == NVVM::MMATypes::tf32)) {
198 result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
205 VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
206 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
207 innerArrayTy.getElementType() == f64Ty ||
208 innerArrayTy.getElementType() == f32Ty)) {
209 for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
210 idx < innerSize; idx++) {
211 result.push_back(b.create<LLVM::ExtractElementOp>(
217 result.push_back(toUse);
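// isMbarrierShared: returns whether the mbarrier object lives in the shared memory address space.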
224 return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
225 barrierType.getMemorySpace()));
230 nvgpu::MBarrierGroupType barrierType) {
235 nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
243 nvgpu::MBarrierGroupType barrierType) {
245 MemRefLayoutAttrInterface layout;
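// Lowers nvgpu.ldmatrix to nvvm.ldmatrix, then bitcasts the returned i32 registers back into the 2-D vector result type.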
256 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
267 auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
268 if (!vectorResultType) {
272 vectorResultType.getElementType(), vectorResultType.getDimSize(1));
274 int64_t num32BitRegs = vectorResultType.getDimSize(0);
276 Type ldMatrixResultType;
277 if (num32BitRegs > 1) {
284 auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
286 getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
287 adaptor.getIndices(), rewriter);
288 Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289 ldMatrixResultType, srcPtr,
291 op.getTranspose() ? NVVM::MMALayout::col
292 : NVVM::MMALayout::row);
298 Type finalResultType = typeConverter->convertType(vectorResultType);
299 Value result = b.create<LLVM::UndefOp>(finalResultType);
300 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
302 num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
304 Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
305 result = b.create<LLVM::InsertValueOp>(result, casted, i);
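// Maps an MLIR element type to the corresponding NVVM MMA PTX operand type (s8, s4, f16, f64, tf32).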
315 static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
318 return NVVM::MMATypes::s8;
320 return NVVM::MMATypes::s4;
322 return NVVM::MMATypes::f16;
324 return NVVM::MMATypes::f64;
326 return NVVM::MMATypes::tf32;
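// Lowers nvgpu.mma.sync to nvvm.mma.sync: deduces the PTX operand types and bails out on f32 inputs unless the tf32-enabled attribute is set.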
334 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
339 VectorType aType = op.getMatrixA().getType();
340 VectorType bType = op.getMatrixB().getType();
341 VectorType cType = op.getMatrixC().getType();
343 std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
346 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347 if (aType.getElementType().isF32() && !tf32Enabled)
350 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351 if (failed(ptxTypeA))
352 return op->emitOpError("failed to deduce operand PTX types");
353 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354 if (failed(ptxTypeB))
355 return op->emitOpError("failed to deduce operand PTX types");
356 std::optional<NVVM::MMATypes> ptxTypeC =
357 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
361 "could not infer the PTX type for the accumulator/result");
364 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365 if (isa<IntegerType>(aType.getElementType()))
366 overflow = NVVM::MMAIntOverflow::satfinite;
378 Value intrinsicResult = b.create<NVVM::MmaOp>(
379 intrinsicResTy, matA, matB, matC,
384 std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
386 std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
387 NVVM::MMALayout::col});
389 desiredRetTy, intrinsicResult,
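// The conversion pass: sets up the type converter, dialect legality, and the pattern set, then runs a partial conversion.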
395 struct ConvertNVGPUToNVVMPass
396 : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
400 registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
401 arith::ArithDialect>();
404 void runOnOperation() override {
410 converter, [](gpu::AddressSpace space) -> unsigned {
412 case gpu::AddressSpace::Global:
413 return static_cast<unsigned>(
415 case gpu::AddressSpace::Workgroup:
416 return static_cast<unsigned>(
418 case gpu::AddressSpace::Private:
421 llvm_unreachable("unknown address space enum value");
427 converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
430 converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
431 Type elemType = type.getFragmented().getElementType();
432 int64_t sizeM = type.getFragmented().getDimSize(0);
433 int64_t sizeN = type.getFragmented().getDimSize(1);
437 numMembers = sizeN / 2;
438 else if (elemType.isF16())
439 numMembers = sizeN / 4;
441 llvm_unreachable("unsupported type for warpgroup accumulator");
444 for (unsigned i = 0; i < numMembers; i++)
445 innerStructBody.push_back(elemType);
446 auto innerStructType =
451 structBody.push_back(innerStructType);
455 return converter.convertType(convertedType);
457 converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
460 converter.addConversion(
461 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
464 converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
465 return converter.convertType(
468 converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
473 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
474 target.addLegalDialect<::mlir::arith::ArithDialect>();
475 target.addLegalDialect<::mlir::memref::MemRefDialect>();
476 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
478 converter, patterns, target);
480 std::move(patterns))))
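// Builds the register-constraint string for the mma.sp.sync inline assembly: one constraint per accumulator output followed by one per input fragment.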
486 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
490 llvm::raw_string_ostream ss(str);
491 for (unsigned i = 0; i < matCSize; i++)
493 for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
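// Builds the textual "mma.sp.sync.aligned.mMnNkK..." PTX string, including the "$N" register placeholders and the sparsity metadata selector.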
506 static std::string buildMmaSparseAsmString(
507 const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
508 unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
509 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
510 std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
511 auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
512 return NVVM::stringifyMMATypes(ptxType);
516 llvm::raw_string_ostream ss(asmStr);
517 ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
518 << shape[2] << ".row.col.";
521 ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
523 ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
524 << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
525 unsigned asmArgIdx = 0;
529 for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
531 for (unsigned i = 0; i < arrSize; i++)
532 ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
535 ss << "$" << asmArgIdx++ << ",";
536 assert(metaDataSelector <= 1);
537 ss << "0x" << metaDataSelector << ";";
544 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
546 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
547 std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
549 int64_t metadataSelector, const std::array<int64_t, 3> &shape,
550 Type intrinsicResultType) {
551 auto asmDialectAttr =
554 const unsigned matASize = unpackedAData.size();
555 const unsigned matBSize = unpackedB.size();
556 const unsigned matCSize = unpackedC.size();
558 std::string asmStr = buildMmaSparseAsmString(
559 shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
560 ptxTypeD, overflow, metadataSelector);
561 std::string constraintStr =
562 buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
565 asmVals.reserve(matASize + matBSize + matCSize + 1);
567 llvm::append_range(asmVals, args);
568 asmVals.push_back(indexData);
570 return b.create<LLVM::InlineAsmOp>(
582 struct NVGPUMmaSparseSyncLowering
587 matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
592 VectorType aType = op.getMatrixA().getType();
593 VectorType bType = op.getMatrixB().getType();
594 VectorType cType = op.getMatrixC().getType();
596 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
597 if (failed(ptxTypeA))
598 return op->emitOpError("failed to deduce operand PTX types");
599 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
600 if (failed(ptxTypeB))
601 return op->emitOpError("failed to deduce operand PTX types");
602 std::optional<NVVM::MMATypes> ptxTypeC =
603 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
607 "could not infer the PTX type for the accumulator/result");
610 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
611 if (aType.getElementType().isF32() && !tf32Enabled)
615 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
616 if (isa<IntegerType>(aType.getElementType()))
617 overflow = NVVM::MMAIntOverflow::satfinite;
631 Value sparseMetadata = adaptor.getSparseMetadata();
632 if (sparseMetadata.getType() !=
634 return op->emitOpError() << "Expected metadata type to be LLVM "
635 "VectorType of 2 i16 elements";
639 FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
640 b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
641 matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
643 if (failed(intrinsicResult))
646 assert((*intrinsicResult).getNumResults() == 1 &&
647 "expected inline asm op returns a single LLVM struct type");
650 (*intrinsicResult)->getResult(0), rewriter));
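// Lowers nvgpu.device_async_copy to nvvm.cp.async, computing strided element pointers for source and destination and the transfer size in bytes.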
655 struct NVGPUAsyncCopyLowering
661 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
665 auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
667 getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
668 adaptor.getDstIndices(), rewriter);
669 FailureOr<unsigned> dstAddressSpace =
670 getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
671 if (failed(dstAddressSpace))
673 loc, "destination memref address space not convertible to integer");
675 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
676 FailureOr<unsigned> srcAddressSpace =
677 getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
678 if (failed(srcAddressSpace))
680 loc, "source memref address space not convertible to integer");
682 Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
683 adaptor.getSrcIndices(), rewriter);
687 scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
688 int64_t dstElements = adaptor.getDstElements().getZExtValue();
689 int64_t sizeInBytes =
690 (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
695 Value srcBytes = adaptor.getSrcElements();
707 srcBytes = b.create<LLVM::LShrOp>(
708 b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
712 NVVM::LoadCacheModifierKind cacheModifier =
713 (op.getBypassL1().value_or(false) && sizeInBytes == 16)
714 ? NVVM::LoadCacheModifierKind::CG
715 : NVVM::LoadCacheModifierKind::CA;
717 b.create<NVVM::CpAsyncOp>(
730 struct NVGPUAsyncCreateGroupLowering
736 matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
738 rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
748 struct NVGPUAsyncWaitLowering
754 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
757 int32_t numGroups = adaptor.getNumGroups().value_or(0);
758 rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
765 struct NVGPUMBarrierCreateLowering
769 template <typename moduleT>
772 MemRefType barrierType) const {
776 auto global = rewriter.create<memref::GlobalOp>(
777 funcOp->getLoc(), "__mbarrier",
783 symbolTable.insert(global);
788 matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
792 rewriter.getContext(), op.getBarriers().getType());
794 memref::GlobalOp global;
796 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
798 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
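// Common base for the mbarrier patterns below; getMbarrierPtr computes the address of one mbarrier element inside the barrier-group memref.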
808 template <typename SourceOp>
814 nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
817 MemRefType mbarrierMemrefType =
820 b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
825 struct NVGPUMBarrierInitLowering
826 : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
827 using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
830 matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
833 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
835 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
836 adaptor.getMbarId(), rewriter);
840 op, barrier, count, adaptor.getPredicate());
843 adaptor.getPredicate());
850 struct NVGPUMBarrierArriveLowering
851 : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
852 using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
854 matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
858 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
859 adaptor.getMbarId(), rewriter);
860 Type tokenType = getTypeConverter()->convertType(
875 struct NVGPUMBarrierArriveNoCompleteLowering
876 : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
877 using MBarrierBasePattern<
878 nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
880 matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
884 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
885 adaptor.getMbarId(), rewriter);
886 Type tokenType = getTypeConverter()->convertType(
891 op, tokenType, barrier, count);
894 op, tokenType, barrier, count);
901 struct NVGPUMBarrierTestWaitLowering
902 : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
903 using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
905 matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
909 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
910 adaptor.getMbarId(), rewriter);
914 op, retType, barrier, adaptor.getToken());
917 op, retType, barrier, adaptor.getToken());
923 struct NVGPUMBarrierArriveExpectTxLowering
924 : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
925 using MBarrierBasePattern<
926 nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
928 matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
932 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
933 adaptor.getMbarId(), rewriter);
938 op, barrier, txcount, adaptor.getPredicate());
943 op, barrier, txcount, adaptor.getPredicate());
948 struct NVGPUMBarrierTryWaitParityLowering
949 : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
950 using MBarrierBasePattern<
951 nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
953 matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
957 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
958 adaptor.getMbarId(), rewriter);
965 op, barrier, phase, ticks);
975 struct NVGPUTmaAsyncLoadOpLowering
976 : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
977 using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
979 matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
982 auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
983 Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
984 adaptor.getDst(), {}, rewriter);
986 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
987 adaptor.getMbarId(), rewriter);
994 op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
996 adaptor.getPredicate());
1001 struct NVGPUTmaAsyncStoreOpLowering
1002 : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1003 using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1005 matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1008 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1009 Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
1010 adaptor.getSrc(), {}, rewriter);
1017 op, adaptor.getTensorMapDescriptor(), dest, coords,
1018 adaptor.getPredicate());
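// Builds the 64-bit wgmma matrix descriptor by packing the swizzle kind, base offset, stride/leading dimensions (with the low address bits excluded), and the 14-bit shared-memory base address into their bit fields.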
1023 struct NVGPUGenerateWarpgroupDescriptorLowering
1029 matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1034 nvgpu::TensorMapSwizzleKind swizzleKind =
1035 op.getTensorMap().getType().getSwizzle();
1038 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1039 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1040 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1043 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1044 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1045 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1049 auto makeConst = [&](uint64_t index) -> Value {
1052 auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1053 return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1055 auto shiftRight = [&](Value value, unsigned shift) -> Value {
1056 return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1058 auto insertBit = [&](Value desc, Value val, int startBit) {
1059 return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1062 int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1063 uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1064 uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1065 uint64_t offsetVal = 0;
1067 Value strideDim = makeConst(strideDimVal);
1068 Value leadDim = makeConst(leadDimVal);
1070 Value baseAddr = getStridedElementPtr(
1071 op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1072 adaptor.getTensor(), {}, rewriter);
1073 Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1075 Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1077 int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1078 startLeadBit = 16, startBaseAddrBit = 0;
1079 Value dsc = makeConst(0);
1081 dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1083 dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1085 dsc = insertBit(dsc, strideDim, startStrideBit);
1087 dsc = insertBit(dsc, leadDim, startLeadBit);
1089 dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1091 LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1092 << "leading_off:" << leadDimVal << "\t"
1093 << "stride_off :" << strideDimVal << "\t"
1094 << "base_offset:" << offsetVal << "\t"
1095 << "layout_type:" << swizzle << " ("
1096 << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1097 << ")\n start_addr : " << baseAddr << "\n");
1113 enum CUtensorMapDataTypeEnum {
1114 CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1115 CU_TENSOR_MAP_DATA_TYPE_UINT16,
1116 CU_TENSOR_MAP_DATA_TYPE_UINT32,
1117 CU_TENSOR_MAP_DATA_TYPE_INT32,
1118 CU_TENSOR_MAP_DATA_TYPE_UINT64,
1119 CU_TENSOR_MAP_DATA_TYPE_INT64,
1120 CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1121 CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1122 CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1123 CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1124 CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1125 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1126 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1130 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
1132 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
1134 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
1136 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
1138 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
1140 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
1142 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1144 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1146 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1148 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1150 llvm_unreachable("Not supported data type");
1153 struct NVGPUTmaCreateDescriptorOpLowering
1158 matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1164 Value tensorElementType =
1165 elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1166 auto promotedOperands = getTypeConverter()->promoteOperands(
1169 Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1170 makeI64Const(b, 5));
1171 for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1172 Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1173 boxArrayPtr, makeI64Const(b, index));
1174 b.create<LLVM::StoreOp>(value, gep);
1177 nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1180 arguments.push_back(promotedOperands[0]);
1181 arguments.push_back(promotedOperands[1]);
1182 arguments.push_back(tensorElementType);
1183 arguments.push_back(
1184 makeI64Const(b, (int)desc.getInterleave()));
1185 arguments.push_back(makeI64Const(b, (int)desc.getSwizzle()));
1186 arguments.push_back(makeI64Const(b, (int)desc.getL2promo()));
1187 arguments.push_back(makeI64Const(b, (int)desc.getOob()));
1188 arguments.push_back(boxArrayPtr);
1202 "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1204 hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1211 struct NVGPUWarpgroupMmaOpLowering
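// WarpgroupGemm splits the large warpgroup GEMM into a sequence of nvvm.wgmma.mma_async ops whose shape is picked from the input element type.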
1235 class WarpgroupGemm {
1236 nvgpu::WarpgroupMmaOp op;
1241 int64_t totalM, totalN, totalK;
1244 int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1247 int iterationM = 0, iterationN = 0, iterationK = 0;
1252 void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1255 if (inputElemType.isTF32()) {
1257 } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1262 } else if (inputElemType.isInteger(1)) {
1265 llvm_unreachable("msg: not supported K shape");
1267 LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1268 << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1272 NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1273 bool useF32 = false) const {
1274 auto getWgmmaType = [=](Type elemType) {
1276 return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1277 if (elemType.isF16())
1278 return NVVM::WGMMATypes::f16;
1280 return NVVM::WGMMATypes::bf16;
1282 return NVVM::WGMMATypes::e4m3;
1284 return NVVM::WGMMATypes::e5m2;
1286 return NVVM::WGMMATypes::b1;
1288 return NVVM::WGMMATypes::s8;
1290 return NVVM::WGMMATypes::u8;
1292 return NVVM::WGMMATypes::s32;
1293 llvm_unreachable("unsupported type");
1300 generateWgmmaLayout(std::optional<bool> transpose) const {
1307 NVVM::MMAShapeAttr generateWgmmaShape() const {
1312 NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1314 NVVM::WGMMAScaleOut::one);
1317 NVVM::WGMMAScaleInAttr generateScaleIn() const {
1319 NVVM::WGMMAScaleIn::one);
1345 Value iterateDescriptorA(Value desc, int i, int j, int k) {
1346 MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1347 Type elemA = matrixTypeA.getElementType();
1349 int tileShapeA = matrixTypeA.getDimSize(1);
1350 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1352 LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1353 << "] [wgmma descriptors] Descriptor A + "
1354 << incrementVal << " | \t ");
1357 return makeAdd(desc, makeI64Const(b, incrementVal));
1371 Value iterateDescriptorB(Value desc, int i, int j, int k) {
1372 MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
1373 Type elemB = matrixTypeB.getElementType();
1375 int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
1377 LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
1380 return makeAdd(desc, makeI64Const(b, incrementVal));
1385 Value generateWgmma(int i, int j, int k, Value matrixC) {
1386 LLVM_DEBUG(DBGS() << "\t wgmma."
1387 << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
1388 << "(A[" << (iterationM * wgmmaM) << ":"
1389 << (iterationM * wgmmaM) + wgmmaM << "]["
1390 << (iterationK * wgmmaK) << ":"
1391 << (iterationK * wgmmaK + wgmmaK) << "] * "
1392 << " B[" << (iterationK * wgmmaK) << ":"
1393 << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
1394 << wgmmaN << "])\n");
1396 Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
1397 Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
1399 Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
1400 NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
1402 Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
1403 NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
1405 Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1406 NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1408 NVVM::MMAShapeAttr shape = generateWgmmaShape();
1409 NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
1410 NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
1411 NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
1412 NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());
1415 op->getContext(), NVVM::MMAIntOverflow::wrapped);
1417 return b.create<NVVM::WgmmaMmaAsyncOp>(
1418 matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1419 itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1424 Value generateWgmmaGroup() {
1426 b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
1430 for (int i = 0; i < iterationM; ++i) {
1431 Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
1432 for (int j = 0; j < iterationN; ++j)
1433 for (int k = 0; k < iterationK; ++k)
1434 matrixC = generateWgmma(i, j, k, matrixC);
1435 wgmmaResults.push_back(matrixC);
1438 wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
1439 wgmmaResult, matrix, idx);
1447 : op(op), b(b), adaptor(adaptor) {
1449 totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
1450 totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
1451 totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
1452 LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
1453 << "] += A[" << totalM << "][" << totalK << "] * B["
1454 << totalK << "][" << totalN << "] ---===\n");
1459 op.getDescriptorA().getType().getTensor().getElementType());
1462 iterationM = totalM / wgmmaM;
1463 iterationN = totalN / wgmmaN;
1464 iterationK = totalK / wgmmaK;
1472 Value generateWarpgroupMma() {
1473 b.create<NVVM::WgmmaFenceAlignedOp>();
1474 Value wgmmaResult = generateWgmmaGroup();
1475 b.create<NVVM::WgmmaGroupSyncAlignedOp>();
1476 b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
1481 matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
1486 WarpgroupGemm warpgroupGemm(op, b, adaptor);
1489 Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
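// Stores the fragmented wgmma accumulator back into the destination memref; each thread writes pairs of adjacent registers at positions derived from its lane and warp id.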
1497 struct NVGPUWarpgroupMmaStoreOpLowering
1543 auto makeConst = [&](int32_t index) -> Value {
1546 Value c1 = makeConst(1);
1547 Value c2 = makeConst(2);
1548 Value c4 = makeConst(4);
1549 Value c8 = makeConst(8);
1550 Value c16 = makeConst(16);
1560 auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1565 Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
1566 Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
1567 Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
1573 Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
1574 Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
1575 Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
1576 Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
1578 Value tj = makeMul(lane4modId, c2);
1579 Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
1581 ti = makeAdd(ti, makeConst(offset));
1583 auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
1586 constexpr unsigned numAdjacentRegisters = 2;
1588 constexpr unsigned numStackedMatrices = 2;
1590 size_t storeCount = (structType.getBody().size() /
1591 (numStackedMatrices * numAdjacentRegisters));
1593 for (size_t i = 0; i < numStackedMatrices; ++i) {
1594 Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
1595 for (size_t j = 0; j < storeCount; ++j) {
1596 Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
1597 size_t structIndex = (i * numAdjacentRegisters) +
1598 (j * (numStackedMatrices * numAdjacentRegisters));
1599 makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
1605 matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1609 Value matriDValue = adaptor.getMatrixD();
1610 auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
1612 auto structType = cast<LLVM::LLVMStructType>(matrixD);
1613 Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
1614 storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
1615 offset += structType.getBody().size();
1622 struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
1627 matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
1631 getTypeConverter()->convertType(op.getMatrixC().getType()));
1632 Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
1636 Value packStruct = b.create<LLVM::UndefOp>(packStructType);
1640 auto structType = cast<LLVM::LLVMStructType>(s);
1641 Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1642 for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1643 structValue = b.create<LLVM::InsertValueOp>(
1646 innerStructs.push_back(structValue);
1650 packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1651 packStruct, matrix, idx);
1658 struct NVGPUTmaPrefetchOpLowering
1662 matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
1665 op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
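// Lowers nvgpu.rcp element-wise to nvvm.rcp.approx.ftz.f32, delegating ranks greater than one to the shared handleMultidimensionalVectors helper.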
1673 matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
1678 VectorType inTy = op.getIn().getType();
1680 auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
1681 Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
1682 int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
1683 for (int i = 0; i < numElems; i++) {
1685 Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
1686 Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
1687 ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
1691 if (inTy.getRank() == 1) {
1692 rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
1696 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
1698 OpAdaptor adaptor(operands);
1699 return convert1DVec(llvm1DVectorTy, adaptor.getIn());
1709 NVGPUMBarrierCreateLowering,
1710 NVGPUMBarrierInitLowering,
1711 NVGPUMBarrierArriveLowering,
1712 NVGPUMBarrierArriveNoCompleteLowering,
1713 NVGPUMBarrierTestWaitLowering,
1714 NVGPUMBarrierTryWaitParityLowering,
1715 NVGPUTmaAsyncLoadOpLowering,
1716 NVGPUTmaAsyncStoreOpLowering,
1717 NVGPUTmaCreateDescriptorOpLowering,
1718 NVGPUTmaPrefetchOpLowering,
1719 NVGPUMBarrierArriveExpectTxLowering,
1720 NVGPUGenerateWarpgroupDescriptorLowering,
1721 NVGPUWarpgroupMmaOpLowering,
1722 NVGPUWarpgroupMmaStoreOpLowering,
1723 NVGPUWarpgroupMmaInitAccumulatorOpLowering,
1724 MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
1725 NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
1726 NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);