#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define DBGSE() (llvm::dbgs())

#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"

/// Number of bits that need to be excluded when building the matrix
/// descriptor for wgmma operations.
constexpr int exclude4LSB = 4;

/// M size of the wgmma.mma_async instruction.
constexpr int kWgmmaSizeM = 64;
/// GPU has 32-bit registers; this function truncates values when a larger
/// width is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  if (type.getIntOrFloatBitWidth() <= 32)
    return value;
  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}

/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements() * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements() * 2, f32Ty));
  }
  return vectorResultType;
}
/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of the result of the `nvgpu.mma.sync` op.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
  // ... (element-type helpers such as f16x2Ty elided in the original listing)

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // 32-bit wide elements can be bitcast directly into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // i32, f64, and f32 values come back as individual scalars; pack each
    // adjacent pair into one 64-bit wide row of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {
      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
                                                         i * 2 + 1);
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
                                                    el.index());
    }
    return result;
  }

  return intrinsicResult;
}
/// The gpu.mma.sync converter below expects matrix fragment operands to be
/// given as 2D vectors; this unpacks them into the flat list of scalars or
/// 32-bit-packed values that the NVVM intrinsic expects.
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  // ... (element-type helpers such as i8x4Ty elided in the original listing)
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);

    // 4xi8 and 8xi4 vectors (and tf32-typed 1xf32 vectors) are provided to
    // the intrinsic as i32 scalars.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
      continue;
    }

    // For i32, f64, and f32 element types, the inner vector must also be
    // unpacked, because the intrinsic expects individual scalars.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(b.create<LLVM::ExtractElementOp>(
            toUse,
            b.create<LLVM::ConstantOp>(i32Ty, b.getI32IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}
/// Returns whether the mbarrier object has shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = {};
  if (isMbarrierShared(barrierType)) {
    memorySpace =
        IntegerAttr::get(IntegerType::get(context, 64),
                         nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
  }
  return memorySpace;
}

/// Returns the memref type that can be used to represent an mbarrier object.
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
  MemRefLayoutAttrInterface layout;
  return MemRefType::get({barrierType.getNumBarriers()},
                         IntegerType::get(context, 64), layout, memorySpace);
}
struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // The result of ldmatrix is a struct of 32-bit integer registers when
    // more than one 32-bit value is returned, and a single i32 otherwise.
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
    Value srcPtr =
        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
        ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // Unpack the i32 registers, bitcast each back to its 32-bit-wide vector
    // type, and repack them into the converted result type.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = b.create<LLVM::UndefOp>(finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
      result = b.create<LLVM::InsertValueOp>(result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Deduces the PTX multiplicand type from the element type of an operand.
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}
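The same type-to-PTX mapping can be checked in isolation. A hedged, self-contained sketch (plain C++; the string tags are illustrative stand-ins for the MLIR types, and the result names mirror the NVVM::MMATypes values):

#include <cstdio>
#include <string>

// Mirrors getNvvmMmaType: 8-bit ints map to s8, 4-bit ints to s4, f16 to
// f16, f64 to f64, and f32 to tf32 (f32 multiplicands are handled as tf32).
std::string nvvmMmaType(const std::string &elemType) {
  if (elemType == "i8")  return "s8";
  if (elemType == "i4")  return "s4";
  if (elemType == "f16") return "f16";
  if (elemType == "f64") return "f64";
  if (elemType == "f32") return "tf32";
  return "<failure>";
}

int main() {
  std::printf("f32 -> %s\n", nvvmMmaType("f32").c_str()); // prints: f32 -> tf32
  return 0;
}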
struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // The shapes of the operand matrices select which intrinsic this op will
    // be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Integer MMA uses saturating-finite overflow behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(desiredRetTy);
    Value intrinsicResult = b.create<NVVM::MmaOp>(
        intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/std::nullopt,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};
struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect,
                    NVVM::NVVMDialect, arith::ArithDialect>();
  }

  void runOnOperation() override {
    LowerToLLVMOptions options(&getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext(), options);
    IRRewriter rewriter(&getContext());
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });
    // Device-side async tokens cannot be materialized in nvvm. Convert them
    // to a dummy i32 so they can be dropped during conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);

      unsigned numMembers;
      if (elemType.isF32() || elemType.isInteger(32))
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
      else
        llvm_unreachable("unsupported type for warpgroup accumulator");

      SmallVector<Type> innerStructBody;
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

      SmallVector<Type> structBody;
      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
        structBody.push_back(innerStructType);

      auto convertedType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    });
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 64));
    });
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
          return converter.convertType(
              IntegerType::get(type.getContext(), 64));
        });
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
          nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
    });
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
/// Returns the constraint string for the sparse MMA inline assembly
/// instruction.
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
    ss << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    ss << "r,";
  // The final operand is for the sparsity metadata; the sparsity selector
  // appears as a direct literal.
  ss << "r";
  return str;
}
/// Returns the string for the `mma.sp.sync` instruction that corresponds to
/// the given parameters. This function does no validation; the parameters
/// are expected to correspond to a valid instruction.
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  // The operand string is structured into sections `{matC elements...},
  // {matA elements...}, {matB elements...}, {matC elements...}`.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
  return asmStr;
}
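To make the generated text concrete, here is a hedged, self-contained sketch (plain C++, no MLIR dependencies; the type names are spelled as literal strings and the register counts are illustrative) that reproduces both the instruction string and the constraint string for an m16n8k32 f16 sparse MMA:

#include <iostream>
#include <sstream>
#include <string>

int main() {
  const int m = 16, n = 8, k = 32; // instruction shape
  const unsigned matASize = 4, matBSize = 2, matCSize = 2;
  const std::string tA = "f16", tB = "f16", tC = "f16", tD = "f16";
  const unsigned metaDataSelector = 0;

  // Instruction text, mirroring buildMmaSparseAsmString (no overflow
  // modifier since the operands are floating point).
  std::ostringstream ss;
  ss << "mma.sp.sync.aligned.m" << m << "n" << n << "k" << k << ".row.col.";
  ss << tD << "." << tA << "." << tB << "." << tC << " ";
  unsigned asmArgIdx = 0;
  for (unsigned arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  ss << "0x" << metaDataSelector << ";";

  // Constraints, mirroring buildMmaSparseAsmConstraintString: one "=r" per
  // C output, one "r" per input register, and a trailing "r" for metadata.
  std::ostringstream cs;
  for (unsigned i = 0; i < matCSize; i++)
    cs << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    cs << "r,";
  cs << "r";

  std::cout << ss.str() << "\n" << cs.str() << "\n";
  // Prints:
  // mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3,$4,$5},{$6,$7},{$8,$9},$10,0x0;
  // =r,=r,r,r,r,r,r,r,r,r,r
  return 0;
}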
/// Emits the inline-assembly op for the sparse `mma.sp.sync` instruction.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  SmallVector<Value> asmVals;
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return b.create<LLVM::InlineAsmOp>(
      /*resultTypes=*/intrinsicResultType,
      /*operands=*/asmVals,
      /*asm_string=*/asmStr,
      /*constraints=*/constraintStr,
      /*has_side_effects=*/true,
      /*is_align_stack=*/false,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
}
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // The shapes of the operand matrices select which instruction this op
    // will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // As with the dense lowering, mma.sync on F32 works only with TF32.
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    // Integer MMA uses saturating-finite overflow behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    // ... (operand unpacking and result-type inference as in the dense
    // MmaSyncOp lowering above)

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() !=
        LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};
struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Location loc = op.getLoc();
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dstPtr =
        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
                             adaptor.getDstIndices(), rewriter);
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    // The intrinsic takes a global pointer, so an address-space cast is
    // needed.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    srcPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, srcPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
    // When the optional srcElements argument is present, only that many
    // elements are read from the source; the remaining destination elements
    // are filled with zeros.
    Value srcBytes = adaptor.getSrcElements();
    if (srcBytes) {
      // Convert the element count to bytes: (bitwidth * srcElements) >> 3.
      Value c3I32 =
          b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
      Value bitwidth = b.create<LLVM::ConstantOp>(
          b.getI32Type(),
          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
      Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
      srcBytes = b.create<LLVM::LShrOp>(
          b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
    }
    // Cache global (.cg) is used only for 16-byte transfers that bypass L1;
    // all other sizes use cache all (.ca).
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    b.create<NVVM::CpAsyncOp>(
        dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes),
        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
        srcBytes);

    // Drop the result token.
    Value zero = b.create<LLVM::ConstantOp>(
        IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};
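The byte arithmetic above is easy to sanity-check offline. A minimal sketch (plain C++; the element width and counts are made-up example values) of the same computation, including the shift-right-by-3 that stands in for a division by 8 and the CG/CA cache-modifier choice:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t elementBitWidth = 16; // e.g. f16
  const int64_t dstElements = 8;
  const int64_t srcElements = 6;      // optional, fewer than dstElements
  const bool bypassL1 = true;

  int64_t sizeInBytes = (elementBitWidth * dstElements) / 8; // 16 bytes
  // The lowering emits (bitwidth * srcElements) >> 3 instead of / 8.
  int64_t srcBytes = (elementBitWidth * srcElements) >> 3;   // 12 bytes

  // cp.async uses .cg only for 16-byte transfers with the bypassL1 flag set.
  const char *modifier = (bypassL1 && sizeInBytes == 16) ? "cg" : "ca";

  std::cout << "dst bytes: " << sizeInBytes << ", src bytes: " << srcBytes
            << ", modifier: ." << modifier << "\n";
  return 0;
}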
struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token: materialize a dummy zero, since the token type
    // was converted to i32.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};
struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, wait until there are no pending groups.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = rewriter.create<memref::GlobalOp>(
        funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/ElementsAttr(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
                                                     global.getName());
    return success();
  }
};
/// Base class for lowering mbarrier operations to nvvm intrinsics.
template <typename SourceOp>
struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  /// Returns the base pointer of the mbarrier object.
  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
                       Value mbarId,
                       ConversionPatternRewriter &rewriter) const {
    MemRefType mbarrierMemrefType =
        nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
    return ConvertToLLVMPattern::getStridedElementPtr(
        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
  }
};
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(mbarrierType)) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
          op, barrier, count, adaptor.getPredicate());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
                                                        adaptor.getPredicate());
    }
    return success();
  }
};
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
                                                                barrier);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
                                                          barrier);
    }
    return success();
  }
};
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
          op, tokenType, barrier, count);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
          op, tokenType, barrier, count);
    }
    return success();
  }
};
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type retType = rewriter.getI1Type();
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
          op, retType, barrier, adaptor.getToken());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
          op, retType, barrier, adaptor.getToken());
    }
    return success();
  }
};
struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value txcount = truncToI32(b, adaptor.getTxcount());
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
          op, barrier, txcount, adaptor.getPredicate());
      return success();
    }
    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
        op, barrier, txcount, adaptor.getPredicate());
    return success();
  }
};
struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase =
        b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
          op, barrier, phase, ticks);
      return success();
    }
    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
                                                               phase, ticks);
    return success();
  }
};
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dest = getStridedElementPtr(op->getLoc(), dstMemrefType,
                                      adaptor.getDst(), {}, rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);

    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords))
      coords[index] = truncToI32(b, value);

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        // ... (multicast / cache-hint operands elided in the original listing)
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    Value srcPtr = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                        adaptor.getSrc(), {}, rewriter);
    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords))
      coords[index] = truncToI32(b, value);

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
        op, adaptor.getTensorMapDescriptor(), srcPtr, coords,
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

    unsigned layout =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
                                                                    : 1;
    unsigned swizzle =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
                                                                    : 0;

    auto ti64 = b.getIntegerType(64);
    auto makeConst = [&](uint64_t index) -> Value {
      return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
    };
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

    Value baseAddr = getStridedElementPtr(
        op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {}, rewriter);
    Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
    // Keep only 14 bits of the base address.
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    // [62,64) swizzle type
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    // [49,52) base offset
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    // [32,46) stride
    dsc = insertBit(dsc, strideDim, startStrideBit);
    // [16,30) leading dimension
    dsc = insertBit(dsc, leadDim, startLeadBit);
    // [0,14) start address
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
                      << "leading_off:" << leadDimVal << "\t"
                      << "stride_off :" << strideDimVal << "\t"
                      << "base_offset:" << offsetVal << "\t"
                      << "layout_type:" << swizzle << " ("
                      << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
                      << ")\n start_addr : " << baseAddr << "\n");

    rewriter.replaceOp(op, dsc);
    return success();
  }
};
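The descriptor bit layout is easy to verify in isolation. A self-contained sketch (plain C++, with made-up stride/lead/address values; the field start bits follow the constants in the lowering above):

#include <cstdint>
#include <cstdio>

// Packs the wgmma matrix-descriptor fields used above: swizzle at bit 62,
// base offset at bit 49, stride at bit 32, leading dimension at bit 16,
// and a 14-bit shared-memory address at bit 0.
static uint64_t makeDescriptor(uint64_t swizzle, uint64_t offset,
                               uint64_t stride, uint64_t lead,
                               uint64_t baseAddr) {
  // Keep only 14 bits of the address, mirroring shiftRight(shiftLeft(x, 46), 50).
  uint64_t base14 = (baseAddr << 46) >> 50;
  uint64_t dsc = 0;
  dsc |= swizzle << 62;
  dsc |= offset << 49;
  dsc |= stride << 32;
  dsc |= lead << 16;
  dsc |= base14;
  return dsc;
}

int main() {
  // Example: 128B swizzle (code 1), stride (128 << 3) >> 4 = 64, and
  // lead (64 * 128) >> 4 = 512 for a hypothetical N = 64 tile.
  uint64_t dsc = makeDescriptor(1, 0, 64, 512, 0x4000);
  std::printf("descriptor = 0x%016llx\n", (unsigned long long)dsc);
  return 0;
}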
/// Mirrors the CUtensorMapDataType enum of the CUDA driver API.
enum CUtensorMapDataTypeEnum {
  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
  CU_TENSOR_MAP_DATA_TYPE_UINT16,
  CU_TENSOR_MAP_DATA_TYPE_UINT32,
  CU_TENSOR_MAP_DATA_TYPE_INT32,
  CU_TENSOR_MAP_DATA_TYPE_UINT64,
  CU_TENSOR_MAP_DATA_TYPE_INT64,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
};

static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
                                    b.getI32IntegerAttr(index));
}

/// Returns the element type of the tensor as an i64 constant matching the
/// driver enum above.
static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
  if (type.isUnsignedInteger(8))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
  if (type.isUnsignedInteger(16))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
  if (type.isUnsignedInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
  if (type.isUnsignedInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
  if (type.isSignlessInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
  if (type.isSignlessInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
  if (type.isF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
  if (type.isF32())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
  if (type.isF64())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
  if (type.isBF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
  llvm_unreachable("Not supported data type");
}
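As a cross-check of the enum values the lowering hands to the driver, here is a hedged standalone sketch (plain C++; the string tags are illustrative stand-ins for the MLIR types tested above):

#include <cstdint>
#include <string>

// Values mirror CUtensorMapDataTypeEnum: UINT8 is 0 and the rest follow
// declaration order (note that INT32 = 3 precedes UINT64 = 4).
int64_t tensorMapDataType(const std::string &type) {
  if (type == "ui8")  return 0; // CU_TENSOR_MAP_DATA_TYPE_UINT8
  if (type == "ui16") return 1; // CU_TENSOR_MAP_DATA_TYPE_UINT16
  if (type == "ui32") return 2; // CU_TENSOR_MAP_DATA_TYPE_UINT32
  if (type == "i32")  return 3; // CU_TENSOR_MAP_DATA_TYPE_INT32
  if (type == "ui64") return 4; // CU_TENSOR_MAP_DATA_TYPE_UINT64
  if (type == "i64")  return 5; // CU_TENSOR_MAP_DATA_TYPE_INT64
  if (type == "f16")  return 6; // CU_TENSOR_MAP_DATA_TYPE_FLOAT16
  if (type == "f32")  return 7; // CU_TENSOR_MAP_DATA_TYPE_FLOAT32
  if (type == "f64")  return 8; // CU_TENSOR_MAP_DATA_TYPE_FLOAT64
  if (type == "bf16") return 9; // CU_TENSOR_MAP_DATA_TYPE_BFLOAT16
  return -1; // not supported
}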
struct NVGPUTmaCreateDescriptorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
    Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

    Value tensorElementType =
        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
                                                 makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
                                        boxArrayPtr, makeI64Const(b, index));
      b.create<LLVM::StoreOp>(value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    // Set the arguments for the runtime call.
    SmallVector<Value> arguments;
    arguments.push_back(promotedOperands[0]); // rank
    arguments.push_back(promotedOperands[1]); // descriptor
    arguments.push_back(tensorElementType);   // data type
    arguments.push_back(makeI64Const(b, (int)desc.getInterleave()));
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle()));
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo()));
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));
    arguments.push_back(boxArrayPtr); // box dimensions

    // Argument types of the runtime call.
    SmallVector<Type> argTypes = {
        llvmInt64Type,   /* int64_t tensorRank */
        llvmPointerType, /* ptr */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmPointerType  /* ptr */
    };
    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

    rewriter.replaceOp(op, tensorMap);
    return success();
  }
};
struct NVGPUWarpgroupMmaOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

  /// Helper class that generates the NVVM ops needed for a warpgroup-level
  /// matrix multiplication.
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    ImplicitLocOpBuilder b;
    OpAdaptor adaptor;

    // Entire shape of the GEMM performed by the op.
    int64_t totalM, totalN, totalK;

    // Shape of one wgmma instruction.
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

    // Iteration counts needed to cover the GEMM shape with wgmma tiles.
    int iterationM = 0, iterationN = 0, iterationK = 0;

    /// Determines the shape of one wgmma instruction as defined in the PTX
    /// programming guide.
    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      wgmmaM = kWgmmaSizeM;
      wgmmaN = sizeN;
      if (inputElemType.isTF32()) {
        wgmmaK = 8;
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
        wgmmaK = 16;
      } else if (inputElemType.isInteger(1)) {
        wgmmaK = 256;
      } else {
        llvm_unreachable("msg: not supported K shape");
      }
      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
    }
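    // The K width selected above depends only on the element type:
    // tf32 -> k8, f16/bf16 -> k16, b1 -> k256, while m is fixed at 64
    // (kWgmmaSizeM) and n follows the output tile, matching the PTX
    // asynchronous warpgroup-level wgmma.mma_async shapes.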
    /// Generates a WGMMATypesAttr from an MLIR type.
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
        if (elemType.isF32() || elemType.isTF32())
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
        if (elemType.isBF16())
          return NVVM::WGMMATypes::bf16;
        if (elemType.isFloat8E4M3FN())
          return NVVM::WGMMATypes::e4m3;
        if (elemType.isFloat8E5M2())
          return NVVM::WGMMATypes::e5m2;
        if (elemType.isInteger(1))
          return NVVM::WGMMATypes::b1;
        if (elemType.isInteger(8))
          return NVVM::WGMMATypes::s8;
        if (elemType.isUnsignedInteger(8))
          return NVVM::WGMMATypes::u8;
        if (elemType.isInteger(32))
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");
      };
      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
    }

    /// Generates the layout attribute for an input matrix.
    NVVM::MMALayoutAttr
    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))
        return NVVM::MMALayoutAttr::get(op->getContext(),
                                        NVVM::MMALayout::col);
      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
    }

    /// Generates the shape attribute for one wgmma instruction.
    NVVM::MMAShapeAttr generateWgmmaShape() const {
      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
    }

    /// Generates the scale attributes for the output and input matrices.
    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
                                          NVVM::WGMMAScaleOut::one);
    }
    NVVM::WGMMAScaleInAttr generateScaleIn() const {
      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
                                         NVVM::WGMMAScaleIn::one);
    }
    /// Basic helper to generate an add.
    Value makeAdd(Value lhs, Value rhs) {
      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
    }

    /// Advances the matrix-A descriptor to the tile selected by the
    /// induction variables i (m dimension) and k.
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();
      int byte = elemA.getIntOrFloatBitWidth() / 8;
      int tileShapeA = matrixTypeA.getDimSize(1);
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
                        << "] [wgmma descriptors] Descriptor A + "
                        << incrementVal << " | \t ");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// Advances the matrix-B descriptor to the tile selected by k.
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }
    /// Generates one WgmmaMmaAsyncOp using the provided GMMA matrix
    /// descriptors, arranged according to the induction variables i, j, k.
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LLVM_DEBUG(DBGS() << "\t wgmma."
                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                        << "(A[" << (iterationM * wgmmaM) << ":"
                        << (iterationM * wgmmaM) + wgmmaM << "]["
                        << (iterationK * wgmmaK) << ":"
                        << (iterationK * wgmmaK + wgmmaK) << "] * "
                        << " B[" << (iterationK * wgmmaK) << ":"
                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
                        << wgmmaN << "])\n");

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return b.create<NVVM::WgmmaMmaAsyncOp>(
          matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
          itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }
    /// Generates the group of wgmma instructions that covers the given GEMM
    /// shape.
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());

      // Perform the GEMM.
      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        Value matrixC =
            b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
                                                    wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }
  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM shape.
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
                        << "] += A[" << totalM << "][" << totalK << "] * B["
                        << totalK << "][" << totalN << "] ---===\n");

      // Find the shape of a single wgmma instruction.
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iteration counts needed to cover the given shape with wgmma tiles.
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
    }
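    // Worked example: an f16 GEMM with totalM = 128, totalN = 128,
    // totalK = 64 selects wgmma tiles of m64 n128 k16, so iterationM/N/K
    // are 2/1/4 and the op expands to 8 NVVM::WgmmaMmaAsyncOp instructions.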
    /// Fences the warpgroup, emits the wgmma group, then synchronizes and
    /// waits on it.
    Value generateWarpgroupMma() {
      b.create<NVVM::WgmmaFenceAlignedOp>();
      Value wgmmaResult = generateWgmmaGroup();
      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
      return wgmmaResult;
    }
  };

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build the helper class.
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Generate the wgmma instructions covering the whole GEMM shape.
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace the op with the fragmented result struct.
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};
struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// Stores a fragmented register matrix owned by the warpgroup (128
  /// threads) into a memref. Each thread owns pairs of adjacent registers
  /// that belong to two 8x8 matrices stacked one below the other per warp.
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(32);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = b.create<arith::IndexCastOp>(it, x);
      Value idy0 = b.create<arith::IndexCastOp>(it, y);
      Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
      Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
      Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
      b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
      b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    if (offset)
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    // Number of 32-bit registers owned per thread.
    constexpr unsigned numAdjacentRegisters = 2;
    // Number of 8x8 matrices stacked one below the other per warp.
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matrixDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
    rewriter.eraseOp(op);
    return success();
  }
};
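The index arithmetic above can be checked offline. A standalone sketch (plain C++, example thread id only) that reproduces, for one thread, the starting row/column each register pair is stored to:

#include <cstdio>

int main() {
  const int warpSize = 32;
  const int tid = 37; // example thread within the 128-thread warpgroup

  const int laneId = tid % warpSize;    // 5
  const int warpId = tid / warpSize;    // 1
  const int lane4Id = laneId / 4;       // 1
  const int lane4modId = laneId % 4;    // 1

  const int tj = lane4modId * 2;        // starting column: 2
  const int ti = lane4Id + warpId * 16; // starting row: 17

  // Each of the two stacked 8x8 matrices adds 8 to the row; each successive
  // register pair adds 8 to the column (storeCount == 2 here for example).
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      std::printf("regs (%d,%d) -> D[%d][%d]\n", i, j, ti + i * 8, tj + j * 8);
  return 0;
}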
struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType =
        cast<LLVM::LLVMStructType>(packStructType.getBody().front())
            .getBody()
            .front();
    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
    Value packStruct = b.create<LLVM::UndefOp>(packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the inner structs and set all of their values to zero.
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = b.create<LLVM::InsertValueOp>(
            structType, structValue, zero, ArrayRef<int64_t>{i});
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs back into a single struct.
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
                                                 packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};
struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
        op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
    return success();
  }
};
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<NVGPUMBarrierCreateLowering,
               NVGPUMBarrierInitLowering,
               NVGPUMBarrierArriveLowering,
               NVGPUMBarrierArriveNoCompleteLowering,
               NVGPUMBarrierTestWaitLowering,
               NVGPUMBarrierTryWaitParityLowering,
               NVGPUTmaAsyncLoadOpLowering,
               NVGPUTmaAsyncStoreOpLowering,
               NVGPUTmaCreateDescriptorOpLowering,
               NVGPUTmaPrefetchOpLowering,
               NVGPUMBarrierArriveExpectTxLowering,
               NVGPUGenerateWarpgroupDescriptorLowering,
               NVGPUWarpgroupMmaOpLowering,
               NVGPUWarpgroupMmaStoreOpLowering,
               NVGPUWarpgroupMmaInitAccumulatorOpLowering,
               MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
               NVGPUMmaSparseSyncLowering>(converter);
}
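Downstream users normally do not instantiate these patterns directly; the pass defined above drives them. A minimal sketch of wiring the conversion into a pass pipeline, assuming the tablegen-generated factory name createConvertNVGPUToNVVMPass (equivalently, run `mlir-opt --convert-nvgpu-to-nvvm`):

#include "mlir/Conversion/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/raw_ostream.h"

// Lowers nvgpu ops in `module` to nvvm/llvm; a sketch, not a complete tool
// (module parsing and dialect registration are omitted).
void lowerNVGPUToNVVM(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createConvertNVGPUToNVVMPass());
  if (mlir::failed(pm.run(module)))
    llvm::errs() << "nvgpu-to-nvvm lowering failed\n";
}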