// Looks up the symbol `name` as an `Op` in `moduleOp`. When the lookup yields
// nothing, the op is created at the start of the module body, forwarding
// `args` to `Op::create`.
// NOTE(review): this listing omits several original lines (the parameter pack,
// the declaration of `ret`, the return) — confirm against the full file.
31template <
typename Op,
typename... Args>
32static Op getOrDefineGlobal(ModuleOp &moduleOp,
const Location loc,
33 ConversionPatternRewriter &rewriter, StringRef name,
36 if (!(ret = moduleOp.lookupSymbol<
Op>(name))) {
// Scoped guard around the temporary move of the insertion point below.
37 ConversionPatternRewriter::InsertionGuard guard(rewriter);
38 rewriter.setInsertionPointToStart(moduleOp.getBody());
39 ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
46 ConversionPatternRewriter &rewriter,
48 LLVM::LLVMFunctionType type) {
49 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
50 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
// Given a lowered memref value (an LLVM struct), returns a pair
// {pointer, element count}.
// The pointer is struct field 1 advanced by the i64 offset in field 2; the
// count is an i32 that starts at 1 and is multiplied by each extracted
// dimension.
// NOTE(review): several original lines (remaining parameters, `dataPtr`,
// `resPtr`, the dimension index operand) are missing from this listing.
53std::pair<Value, Value> getRawPtrAndSize(
const Location loc,
54 ConversionPatternRewriter &rewriter,
57 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
// Field 1 of the struct is read as a pointer.
59 LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
// Field 2 of the struct is read as an i64 offset.
60 Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
61 rewriter.getI64Type(), memRef, 2);
// Advance the pointer by `offset` elements of type `elType`.
63 LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
// NOTE(review): the constant's result type is i32 but the attribute comes from
// getIndexAttr(1), i.e. it is index-typed — confirm this mismatch is intended.
64 Value size = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
65 rewriter.getIndexAttr(1));
// Only structs with more than three fields are scanned for dimensions.
66 if (cast<LLVM::LLVMStructType>(memRef.
getType()).getBody().size() > 3) {
67 for (
int64_t i = 0; i < rank; ++i) {
68 Value dim = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
// Each dimension is truncated to i32 before being multiplied into the count.
// NOTE(review): a count that does not fit in 32 bits would wrap — confirm
// callers never pass such memrefs.
70 dim = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), dim);
72 LLVM::MulOp::create(rewriter, loc, rewriter.getI32Type(), dim, size);
77 return {resPtr, size};
// Interior of the abstract traits interface that the MPI lowerings query for
// implementation-specific values. The class header and some parameter lines
// are not visible in this listing.
// Factory returning the traits object to use for `moduleOp`.
91 static std::unique_ptr<MPIImplTraits>
get(ModuleOp &moduleOp);
// Stores a reference to the module; the module must outlive this object.
93 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
95 virtual ~MPIImplTraits() =
default;
97 ModuleOp &getModuleOp() {
return moduleOp; }
// Returns a Value standing for the world communicator.
103 virtual Value getCommWorld(Location loc,
104 ConversionPatternRewriter &rewriter) = 0;
// Converts a communicator value into the form passed to the MPI functions.
108 virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter,
// Integer that the receive lowering materialises as an i64 constant and turns
// into the status pointer argument.
112 virtual intptr_t getStatusIgnore() = 0;
// Pointer value the all-reduce lowering passes as send buffer when send and
// receive buffers are the same value.
115 virtual void *getInPlace() = 0;
// Returns a Value identifying the MPI datatype for an element type.
119 virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter,
// Returns a Value identifying the MPI reduction operation for `opAttr`.
124 virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter,
125 mpi::MPI_ReductionOpEnum opAttr) = 0;
// Traits implementation in which communicators, datatypes and reduction
// operations are all represented by integer constants.
// NOTE(review): the numeric values are presumably the handle encodings from
// MPICH's mpi.h — verify against the targeted MPICH version.
132class MPICHImplTraits :
public MPIImplTraits {
// Datatype handles.
133 static constexpr int MPI_FLOAT = 0x4c00040a;
134 static constexpr int MPI_DOUBLE = 0x4c00080b;
135 static constexpr int MPI_INT8_T = 0x4c000137;
136 static constexpr int MPI_INT16_T = 0x4c000238;
137 static constexpr int MPI_INT32_T = 0x4c000439;
138 static constexpr int MPI_INT64_T = 0x4c00083a;
139 static constexpr int MPI_UINT8_T = 0x4c00013b;
140 static constexpr int MPI_UINT16_T = 0x4c00023c;
141 static constexpr int MPI_UINT32_T = 0x4c00043d;
142 static constexpr int MPI_UINT64_T = 0x4c00083e;
// Reduction-operation handles.
143 static constexpr int MPI_MAX = 0x58000001;
144 static constexpr int MPI_MIN = 0x58000002;
145 static constexpr int MPI_SUM = 0x58000003;
146 static constexpr int MPI_PROD = 0x58000004;
147 static constexpr int MPI_LAND = 0x58000005;
148 static constexpr int MPI_BAND = 0x58000006;
149 static constexpr int MPI_LOR = 0x58000007;
150 static constexpr int MPI_BOR = 0x58000008;
151 static constexpr int MPI_LXOR = 0x58000009;
152 static constexpr int MPI_BXOR = 0x5800000a;
153 static constexpr int MPI_MINLOC = 0x5800000b;
154 static constexpr int MPI_MAXLOC = 0x5800000c;
155 static constexpr int MPI_REPLACE = 0x5800000d;
156 static constexpr int MPI_NO_OP = 0x5800000e;
// Inherit the base-class constructor taking the module.
159 using MPIImplTraits::MPIImplTraits;
161 ~MPICHImplTraits()
override =
default;
// The world communicator is emitted as an i64 constant.
163 Value getCommWorld(
const Location loc,
164 ConversionPatternRewriter &rewriter)
override {
165 static constexpr int MPI_COMM_WORLD = 0x44000000;
166 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
// Communicators are passed to the MPI functions as i32: truncate the value.
170 Value castComm(
const Location loc, ConversionPatternRewriter &rewriter,
171 Value comm)
override {
172 return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
// Status-ignore sentinel is the integer 1.
175 intptr_t getStatusIgnore()
override {
return 1; }
// In-place sentinel is the pointer value -1.
177 void *getInPlace()
override {
return reinterpret_cast<void *
>(-1); }
// Maps an element type to one of the datatype constants above and returns it
// as an i32 constant. Most branches of the type chain are not visible here.
179 Value getDataType(
const Location loc, ConversionPatternRewriter &rewriter,
180 Type type)
override {
184 else if (type.
isF64())
189 mtype = MPI_UINT64_T;
193 mtype = MPI_UINT32_T;
197 mtype = MPI_UINT16_T;
// Unsupported element types trip this assertion (active only in builds with
// assertions enabled).
203 assert(
false &&
"unsupported type");
204 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
// Maps the reduction enum to one of the operation constants above and returns
// it as an i32 constant; `op` starts out as MPI_NO_OP.
208 Value getMPIOp(
const Location loc, ConversionPatternRewriter &rewriter,
209 mpi::MPI_ReductionOpEnum opAttr)
override {
210 int32_t op = MPI_NO_OP;
212 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
215 case mpi::MPI_ReductionOpEnum::MPI_MAX:
218 case mpi::MPI_ReductionOpEnum::MPI_MIN:
221 case mpi::MPI_ReductionOpEnum::MPI_SUM:
224 case mpi::MPI_ReductionOpEnum::MPI_PROD:
227 case mpi::MPI_ReductionOpEnum::MPI_LAND:
230 case mpi::MPI_ReductionOpEnum::MPI_BAND:
233 case mpi::MPI_ReductionOpEnum::MPI_LOR:
236 case mpi::MPI_ReductionOpEnum::MPI_BOR:
239 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
242 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
245 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
248 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
251 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
255 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
// Traits implementation in which communicators, datatypes and reduction
// operations are addresses of external global symbols (named "ompi_mpi_*")
// typed as opaque LLVM structs.
262class OMPIImplTraits :
public MPIImplTraits {
// Returns the global named `name` in the module, creating an externally
// linked declaration of struct type `type` through getOrDefineGlobal when it
// is not present yet.
263 LLVM::GlobalOp getOrDefineExternalStruct(
const Location loc,
264 ConversionPatternRewriter &rewriter,
266 LLVM::LLVMStructType type) {
268 return getOrDefineGlobal<LLVM::GlobalOp>(
269 getModuleOp(), loc, rewriter, name, type,
false,
270 LLVM::Linkage::External, name,
// Inherit the base-class constructor taking the module.
275 using MPIImplTraits::MPIImplTraits;
277 ~OMPIImplTraits()
override =
default;
// The world communicator is the address of the external global
// "ompi_mpi_comm_world", converted to an i64.
279 Value getCommWorld(
const Location loc,
280 ConversionPatternRewriter &rewriter)
override {
281 auto *context = rewriter.getContext();
284 LLVM::LLVMStructType::getOpaque(
"ompi_communicator_t", context);
285 StringRef name =
"ompi_mpi_comm_world";
// Make sure the global is declared before its address is taken.
288 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
291 auto comm = LLVM::AddressOfOp::create(rewriter, loc,
292 LLVM::LLVMPointerType::get(context),
293 SymbolRefAttr::get(context, name));
294 return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
// Communicators are passed to the MPI functions as pointers: convert the
// integer value back to a pointer.
297 Value castComm(
const Location loc, ConversionPatternRewriter &rewriter,
298 Value comm)
override {
299 return LLVM::IntToPtrOp::create(
300 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
// Status-ignore sentinel is the integer 0.
303 intptr_t getStatusIgnore()
override {
return 0; }
// In-place sentinel is the pointer value 1.
305 void *getInPlace()
override {
return reinterpret_cast<void *
>(1); }
// Picks the name of the external datatype global for an element type and
// returns its address. The conditions of most branches are not visible here.
307 Value getDataType(
const Location loc, ConversionPatternRewriter &rewriter,
308 Type type)
override {
311 mtype =
"ompi_mpi_float";
312 else if (type.
isF64())
313 mtype =
"ompi_mpi_double";
315 mtype =
"ompi_mpi_int64_t";
317 mtype =
"ompi_mpi_uint64_t";
319 mtype =
"ompi_mpi_int32_t";
321 mtype =
"ompi_mpi_uint32_t";
323 mtype =
"ompi_mpi_int16_t";
325 mtype =
"ompi_mpi_uint16_t";
327 mtype =
"ompi_mpi_int8_t";
329 mtype =
"ompi_mpi_uint8_t";
// Unsupported element types trip this assertion (active only in builds with
// assertions enabled).
331 assert(
false &&
"unsupported type");
333 auto *context = rewriter.getContext();
336 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_datatype_t", context);
// Make sure the global is declared before its address is taken.
338 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
340 return LLVM::AddressOfOp::create(rewriter, loc,
341 LLVM::LLVMPointerType::get(context),
342 SymbolRefAttr::get(context, mtype));
// Picks the name of the external reduction-operation global for `opAttr` and
// returns its address. Several case bodies are not visible here.
345 Value getMPIOp(
const Location loc, ConversionPatternRewriter &rewriter,
346 mpi::MPI_ReductionOpEnum opAttr)
override {
// MPI_OP_NULL is mapped to the "no op" symbol.
349 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
350 op =
"ompi_mpi_no_op";
352 case mpi::MPI_ReductionOpEnum::MPI_MAX:
355 case mpi::MPI_ReductionOpEnum::MPI_MIN:
358 case mpi::MPI_ReductionOpEnum::MPI_SUM:
361 case mpi::MPI_ReductionOpEnum::MPI_PROD:
362 op =
"ompi_mpi_prod";
364 case mpi::MPI_ReductionOpEnum::MPI_LAND:
365 op =
"ompi_mpi_land";
367 case mpi::MPI_ReductionOpEnum::MPI_BAND:
368 op =
"ompi_mpi_band";
370 case mpi::MPI_ReductionOpEnum::MPI_LOR:
373 case mpi::MPI_ReductionOpEnum::MPI_BOR:
376 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
377 op =
"ompi_mpi_lxor";
379 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
380 op =
"ompi_mpi_bxor";
382 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
383 op =
"ompi_mpi_minloc";
385 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
386 op =
"ompi_mpi_maxloc";
388 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
389 op =
"ompi_mpi_replace";
392 auto *context = rewriter.getContext();
395 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_op_t", context);
// Make sure the global is declared before its address is taken.
397 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
399 return LLVM::AddressOfOp::create(rewriter, loc,
400 LLVM::LLVMPointerType::get(context),
401 SymbolRefAttr::get(context, op));
// Chooses the traits implementation from the "MPI:Implementation" DLTI entry
// of `moduleOp`:
//   - string "OpenMPI" selects OMPIImplTraits;
//   - string "MPICH" selects MPICHImplTraits;
//   - any other value (or a non-string attribute) emits a warning on the
//     module and falls back to MPICHImplTraits.
405std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
// The trailing `false` is the query's emitError argument.
406 auto attr =
dlti::query(moduleOp, {
"MPI:Implementation"},
false);
// NOTE(review): the condition guarding this early return is not visible in
// this listing; presumably it covers a failed query — confirm.
408 return std::make_unique<MPICHImplTraits>(moduleOp);
409 auto strAttr = dyn_cast<StringAttr>(attr.value());
410 if (strAttr && strAttr.getValue() ==
"OpenMPI")
411 return std::make_unique<OMPIImplTraits>(moduleOp);
412 if (!strAttr || strAttr.getValue() !=
"MPICH")
413 moduleOp.emitWarning() <<
"Unknown \"MPI:Implementation\" value in DLTI ("
414 << (strAttr ? strAttr.getValue() :
"<NULL>")
415 <<
"), defaulting to MPICH";
416 return std::make_unique<MPICHImplTraits>(moduleOp);
// Rewrites mpi::InitOp into an LLVM call to a function of type
// i32(ptr, ptr). The callee name and the call operands are on lines missing
// from this listing.
423struct InitOpLowering :
public ConvertOpToLLVMPattern<mpi::InitOp> {
427 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter)
const override {
429 Location loc = op.getLoc();
432 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
// A null pointer value, available as `llvmnull`.
435 auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
436 Value llvmnull = nullPtrOp.getRes();
439 auto moduleOp = op->getParentOfType<ModuleOp>();
443 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
445 LLVM::LLVMFuncOp initDecl =
// The original op is replaced by the call.
449 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
// Rewrites mpi::FinalizeOp into an LLVM call to "MPI_Finalize", declared with
// type i32() and called without operands.
460struct FinalizeOpLowering :
public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
464 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter)
const override {
467 Location loc = op.getLoc();
470 auto moduleOp = op->getParentOfType<ModuleOp>();
// Despite the `init` prefix, this is the type of the finalize function.
473 auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
476 moduleOp, loc, rewriter,
"MPI_Finalize", initFuncType);
// The original op is replaced by the call.
479 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
ValueRange{});
// Rewrites mpi::CommWorldOp by replacing it with the value that the
// module's MPI implementation traits produce for the world communicator.
489struct CommWorldOpLowering :
public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
493 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
494 ConversionPatternRewriter &rewriter)
const override {
496 auto moduleOp = op->getParentOfType<ModuleOp>();
// Traits are looked up from the enclosing module on every rewrite.
497 auto mpiTraits = MPIImplTraits::get(moduleOp);
499 rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
// Rewrites mpi::CommSplitOp into a call to "MPI_Comm_split" whose last
// argument is a stack slot receiving the new communicator; the slot is then
// loaded and widened to i64 to form the op's communicator result.
509struct CommSplitOpLowering :
public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
513 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter)
const override {
516 auto moduleOp = op->getParentOfType<ModuleOp>();
517 auto mpiTraits = MPIImplTraits::get(moduleOp);
518 Type i32 = rewriter.getI32Type();
519 Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
520 Location loc = op.getLoc();
// Input communicator in the form the MPI function expects.
523 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
524 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
// Stack slot for one element of the communicator's type.
526 LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.
getType(), one);
// Function type: i32(comm, i32, i32, ptr).
530 LLVM::LLVMFunctionType::get(i32, {comm.
getType(), i32, i32, ptrType});
533 "MPI_Comm_split", funcType);
536 LLVM::CallOp::create(rewriter, loc, funcDecl,
538 adaptor.getKey(), outPtr.getRes()});
// NOTE(review): the slot was allocated with comm.getType(), yet it is loaded
// as i32 and sign-extended. When castComm yields a pointer (as the OpenMPI
// traits do) only 32 bits of the stored handle are read — confirm this is
// intended.
541 Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
542 res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
// Replacement values: the call result (the guarding condition is not visible
// in this listing) followed by the new communicator.
546 SmallVector<Value> replacements;
548 replacements.push_back(callOp.getResult());
551 replacements.push_back(res);
552 rewriter.replaceOp(op, replacements);
// Rewrites mpi::CommRankOp into a call to "MPI_Comm_rank" that writes the
// rank into an i32 stack slot, then loads the slot to obtain the rank value.
562struct CommRankOpLowering :
public ConvertOpToLLVMPattern<mpi::CommRankOp> {
566 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
567 ConversionPatternRewriter &rewriter)
const override {
569 Location loc = op.getLoc();
570 MLIRContext *context = rewriter.getContext();
571 Type i32 = rewriter.getI32Type();
574 Type ptrType = LLVM::LLVMPointerType::get(context);
577 auto moduleOp = op->getParentOfType<ModuleOp>();
579 auto mpiTraits = MPIImplTraits::get(moduleOp);
// Communicator in the form the MPI function expects.
581 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// Function type: i32(comm, ptr).
585 LLVM::LLVMFunctionType::get(i32, {comm.
getType(), ptrType});
588 moduleOp, loc, rewriter,
"MPI_Comm_rank", rankFuncType);
// Stack slot for one i32 that receives the rank.
591 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
592 auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
593 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
598 LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
// Replacement values: the call result (the guarding condition is not visible
// in this listing) followed by the loaded rank.
602 SmallVector<Value> replacements;
604 replacements.push_back(callOp.getResult());
607 replacements.push_back(loadedRank.getRes());
608 rewriter.replaceOp(op, replacements);
// Produces a Value holding the communicator size.
// First creates an mpi::CommSizeOp on `commOrg` and tries to fold it via
// FoldToDLTIConst with key "MPI:comm_world_size"; on success that op's size
// result is returned. Otherwise the probe op is erased and a new
// mpi::CommSizeOp on `commAdapt` is created and its size result returned.
618static Value createOrFoldCommSize(ConversionPatternRewriter &rewriter,
619 Location loc, Value commOrg,
621 auto i32 = rewriter.getI32Type();
622 auto nRanksOp = mpi::CommSizeOp::create(rewriter, loc, i32, commOrg);
623 if (succeeded(
FoldToDLTIConst(nRanksOp,
"MPI:comm_world_size", rewriter)))
624 return nRanksOp.getSize();
// Folding failed: drop the probe op and build the op on the other
// communicator value instead.
625 rewriter.eraseOp(nRanksOp);
626 return mpi::CommSizeOp::create(rewriter, loc, i32, commAdapt).getSize();
// Rewrites mpi::CommSizeOp into a call to "MPI_Comm_size" that writes the
// size into an i32 stack slot, then loads the slot to obtain the size value.
629struct CommSizeOpLowering :
public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
633 matchAndRewrite(mpi::CommSizeOp op, OpAdaptor adaptor,
634 ConversionPatternRewriter &rewriter)
const override {
636 Location loc = op.getLoc();
637 MLIRContext *context = rewriter.getContext();
638 Type i32 = rewriter.getI32Type();
641 Type ptrType = LLVM::LLVMPointerType::get(context);
644 auto moduleOp = op->getParentOfType<ModuleOp>();
646 auto mpiTraits = MPIImplTraits::get(moduleOp);
// Communicator in the form the MPI function expects.
648 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// Function type: i32(comm, ptr).
652 LLVM::LLVMFunctionType::get(i32, {comm.
getType(), ptrType});
655 moduleOp, loc, rewriter,
"MPI_Comm_size", SizeFuncType);
// Stack slot for one i32 that receives the size.
658 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
659 auto sizeptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
660 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
665 LLVM::LoadOp::create(rewriter, loc, i32, sizeptr.getResult());
// Replacement values: the call result (the guarding condition is not visible
// in this listing) followed by the loaded size.
669 SmallVector<Value> replacements;
671 replacements.push_back(callOp.getResult());
674 replacements.push_back(loadedSize.getRes());
675 rewriter.replaceOp(op, replacements);
// Rewrites mpi::SendOp into an LLVM call whose function type is
// i32(ptr, i32, datatype, i32, i32, comm). The callee name and the leading
// call operands are on lines missing from this listing.
685struct SendOpLowering :
public ConvertOpToLLVMPattern<mpi::SendOp> {
689 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
690 ConversionPatternRewriter &rewriter)
const override {
692 Location loc = op.getLoc();
693 MLIRContext *context = rewriter.getContext();
694 Type i32 = rewriter.getI32Type();
695 Type elemType = op.getRef().getType().getElementType();
696 int64_t rank = op.getRef().getType().getRank();
699 Type ptrType = LLVM::LLVMPointerType::get(context);
702 auto moduleOp = op->getParentOfType<ModuleOp>();
// Raw buffer pointer and i32 element count of the converted memref.
705 auto [dataPtr, size] =
706 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
// Implementation-specific datatype and communicator values.
707 auto mpiTraits = MPIImplTraits::get(moduleOp);
708 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
709 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
713 auto funcType = LLVM::LLVMFunctionType::get(
714 i32, {ptrType, i32, dataType.
getType(), i32, i32, comm.
getType()});
716 LLVM::LLVMFuncOp funcDecl =
720 auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
723 adaptor.getTag(), comm});
// The op is either replaced by the call result or erased; the condition
// selecting between the two is not visible in this listing.
725 rewriter.replaceOp(op, funcCall.getResult());
727 rewriter.eraseOp(op);
// Rewrites mpi::RecvOp into an LLVM call whose function type is
// i32(ptr, i32, datatype, i32, i32, comm, ptr); the final pointer argument is
// the traits' status-ignore value. The callee name is on a line missing from
// this listing.
737struct RecvOpLowering :
public ConvertOpToLLVMPattern<mpi::RecvOp> {
741 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
742 ConversionPatternRewriter &rewriter)
const override {
744 Location loc = op.getLoc();
745 MLIRContext *context = rewriter.getContext();
746 Type i32 = rewriter.getI32Type();
747 Type i64 = rewriter.getI64Type();
748 Type elemType = op.getRef().getType().getElementType();
749 int64_t rank = op.getRef().getType().getRank();
752 Type ptrType = LLVM::LLVMPointerType::get(context);
755 auto moduleOp = op->getParentOfType<ModuleOp>();
// Raw buffer pointer and i32 element count of the converted memref.
758 auto [dataPtr, size] =
759 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
// Implementation-specific datatype and communicator values.
760 auto mpiTraits = MPIImplTraits::get(moduleOp);
761 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
762 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// The status-ignore sentinel is emitted as an i64 constant and then converted
// to a pointer.
763 Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
764 mpiTraits->getStatusIgnore());
766 LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
771 LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.
getType(), i32,
772 i32, comm.
getType(), ptrType});
774 LLVM::LLVMFuncOp funcDecl =
// Operands: buffer, count, datatype, source, tag, communicator, status.
778 auto funcCall = LLVM::CallOp::create(
779 rewriter, loc, funcDecl,
780 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
781 adaptor.getTag(), comm, statusIgnore});
// The op is either replaced by the call result or erased; the condition
// selecting between the two is not visible in this listing.
783 rewriter.replaceOp(op, funcCall.getResult());
785 rewriter.eraseOp(op);
// Rewrites mpi::AllGatherOp into an LLVM call taking
// (sendPtr, sendSize, sendType, recvPtr, recvCountPerRank, recvType, comm),
// where recvCountPerRank is the receive buffer's element count divided by the
// number of ranks. The callee name is on a line missing from this listing.
795struct AllGatherOpLowering :
public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
799 matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
800 ConversionPatternRewriter &rewriter)
const override {
801 Location loc = op.getLoc();
802 MLIRContext *context = rewriter.getContext();
803 Type sElemType = op.getSendbuf().getType().getElementType();
804 Type rElemType = op.getRecvbuf().getType().getElementType();
805 int64_t sRank = op.getSendbuf().getType().getRank();
806 int64_t rRank = op.getRecvbuf().getType().getRank();
// Raw pointers and i32 element counts of both converted memrefs.
807 auto [sendPtr, sendSize] =
808 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, sElemType);
809 auto [recvPtr, recvSize] =
810 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, rElemType);
// Implementation-specific datatype and communicator values.
812 auto moduleOp = op->getParentOfType<ModuleOp>();
813 auto mpiTraits = MPIImplTraits::get(moduleOp);
814 Value sDataType = mpiTraits->getDataType(loc, rewriter, sElemType);
815 Value rDataType = mpiTraits->getDataType(loc, rewriter, rElemType);
816 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
818 Type ptrType = LLVM::LLVMPointerType::get(context);
819 Type i32 = rewriter.getI32Type();
824 auto funcType = LLVM::LLVMFunctionType::get(
825 i32, {ptrType, i32, sDataType.
getType(), ptrType, i32,
828 LLVM::LLVMFuncOp funcDecl =
// Number of ranks, taken from the original and the converted communicator
// values.
833 createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
// Unsigned division of the receive count by the number of ranks.
834 Value recvCountPerRank =
835 LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
839 LLVM::CallOp::create(rewriter, loc, funcDecl,
840 ValueRange{sendPtr, sendSize, sDataType, recvPtr,
841 recvCountPerRank, rDataType, comm});
// The op is either replaced by the call result or erased; the condition
// selecting between the two is not visible in this listing.
844 rewriter.replaceOp(op, funcCall.getResult());
846 rewriter.eraseOp(op);
// Rewrites mpi::AllReduceOp into an LLVM call taking
// (sendPtr, recvPtr, sendSize, datatype, op, comm). When the converted send
// and receive buffers are the same value, the send pointer is replaced by the
// traits' in-place sentinel. The callee name is on a line missing from this
// listing.
856struct AllReduceOpLowering :
public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
860 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
861 ConversionPatternRewriter &rewriter)
const override {
862 Location loc = op.getLoc();
863 MLIRContext *context = rewriter.getContext();
864 Type i32 = rewriter.getI32Type();
865 Type i64 = rewriter.getI64Type();
// The send buffer's element type is used for both buffers.
866 Type elemType = op.getSendbuf().getType().getElementType();
867 int64_t sRank = op.getSendbuf().getType().getRank();
868 int64_t rRank = op.getRecvbuf().getType().getRank();
871 Type ptrType = LLVM::LLVMPointerType::get(context);
872 auto moduleOp = op->getParentOfType<ModuleOp>();
873 auto mpiTraits = MPIImplTraits::get(moduleOp);
// Raw pointers and i32 element counts of both converted memrefs; only the
// send count is passed to the call below.
874 auto [sendPtr, sendSize] =
875 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
876 auto [recvPtr, recvSize] =
877 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
// Same value for both buffers: pass the in-place sentinel as send pointer,
// built as an integer constant converted to a pointer.
880 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
881 sendPtr = LLVM::ConstantOp::create(
883 reinterpret_cast<int64_t
>(mpiTraits->getInPlace()));
884 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
// Implementation-specific datatype, reduction-op and communicator values.
887 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
888 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
889 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
893 auto funcType = LLVM::LLVMFunctionType::get(
897 LLVM::LLVMFuncOp funcDecl =
901 auto funcCall = LLVM::CallOp::create(
902 rewriter, loc, funcDecl,
903 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
// The op is either replaced by the call result or erased; the condition
// selecting between the two is not visible in this listing.
906 rewriter.replaceOp(op, funcCall.getResult());
908 rewriter.eraseOp(op);
919struct FuncToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
923 void populateConvertToLLVMConversionPatterns(
924 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
925 RewritePatternSet &
patterns)
const final {
// NOTE(review): the enclosing function header is not visible in this listing;
// presumably this is the body that populates the MPI-to-LLVM patterns —
// confirm against the full file.
// mpi::CommType values are converted to 64-bit integers.
940 converter.addConversion([](mpi::CommType type) {
941 return IntegerType::get(type.getContext(), 64);
// Register every MPI op lowering with the type converter.
943 patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
944 CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
945 SendOpLowering, RecvOpLowering, AllGatherOpLowering,
946 AllReduceOpLowering>(converter);
951 dialect->addInterfaces<FuncToLLVMDialectInterface>();
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
LogicalResult FoldToDLTIConst(OpT op, const char *key, mlir::PatternRewriter &b)
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...