32template <
typename Op,
typename... Args>
33static Op getOrDefineGlobal(ModuleOp &moduleOp,
const Location loc,
34 ConversionPatternRewriter &rewriter, StringRef name,
37 if (!(ret = moduleOp.lookupSymbol<
Op>(name))) {
38 ConversionPatternRewriter::InsertionGuard guard(rewriter);
39 rewriter.setInsertionPointToStart(moduleOp.getBody());
40 ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
47 ConversionPatternRewriter &rewriter,
49 LLVM::LLVMFunctionType type) {
50 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
51 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
54std::pair<Value, Value> getRawPtrAndSize(
const Location loc,
55 ConversionPatternRewriter &rewriter,
58 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
60 LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
61 Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
62 rewriter.getI64Type(), memRef, 2);
64 LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
65 Value size = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
66 rewriter.getIndexAttr(1));
67 if (cast<LLVM::LLVMStructType>(memRef.
getType()).getBody().size() > 3) {
68 for (
int64_t i = 0; i < rank; ++i) {
69 Value dim = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
71 dim = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), dim);
73 LLVM::MulOp::create(rewriter, loc, rewriter.getI32Type(), dim, size);
78 return {resPtr, size};
92 static std::unique_ptr<MPIImplTraits>
get(ModuleOp &moduleOp);
94 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
96 virtual ~MPIImplTraits() =
default;
98 ModuleOp &getModuleOp() {
return moduleOp; }
104 virtual Value getCommWorld(Location loc,
105 ConversionPatternRewriter &rewriter) = 0;
109 virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter,
113 virtual intptr_t getStatusIgnore() = 0;
116 virtual void *getInPlace() = 0;
120 virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter,
125 virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter,
126 mpi::MPI_ReductionOpEnum opAttr) = 0;
133class MPICHImplTraits :
public MPIImplTraits {
// Predefined MPICH handles, kept as plain ints; getDataType() and getMPIOp()
// emit them as i32 constants.
// NOTE(review): presumably copied from MPICH's mpi.h; re-check these values if
// a different MPICH ABI is targeted.

// -- Datatypes --
static constexpr int MPI_FLOAT = 0x4c00040a;
static constexpr int MPI_DOUBLE = 0x4c00080b;
static constexpr int MPI_INT8_T = 0x4c000137;
static constexpr int MPI_INT16_T = 0x4c000238;
static constexpr int MPI_INT32_T = 0x4c000439;
static constexpr int MPI_INT64_T = 0x4c00083a;
static constexpr int MPI_UINT8_T = 0x4c00013b;
static constexpr int MPI_UINT16_T = 0x4c00023c;
static constexpr int MPI_UINT32_T = 0x4c00043d;
static constexpr int MPI_UINT64_T = 0x4c00083e;

// -- Reduction operators --
static constexpr int MPI_MAX = 0x58000001;
static constexpr int MPI_MIN = 0x58000002;
static constexpr int MPI_SUM = 0x58000003;
static constexpr int MPI_PROD = 0x58000004;
static constexpr int MPI_LAND = 0x58000005;
static constexpr int MPI_BAND = 0x58000006;
static constexpr int MPI_LOR = 0x58000007;
static constexpr int MPI_BOR = 0x58000008;
static constexpr int MPI_LXOR = 0x58000009;
static constexpr int MPI_BXOR = 0x5800000a;
static constexpr int MPI_MINLOC = 0x5800000b;
static constexpr int MPI_MAXLOC = 0x5800000c;
static constexpr int MPI_REPLACE = 0x5800000d;
static constexpr int MPI_NO_OP = 0x5800000e;
160 using MPIImplTraits::MPIImplTraits;
162 ~MPICHImplTraits()
override =
default;
164 Value getCommWorld(
const Location loc,
165 ConversionPatternRewriter &rewriter)
override {
166 static constexpr int MPI_COMM_WORLD = 0x44000000;
167 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
171 Value castComm(
const Location loc, ConversionPatternRewriter &rewriter,
172 Value comm)
override {
173 return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
176 intptr_t getStatusIgnore()
override {
return 1; }
178 void *getInPlace()
override {
return reinterpret_cast<void *
>(-1); }
180 Value getDataType(
const Location loc, ConversionPatternRewriter &rewriter,
181 Type type)
override {
185 else if (type.
isF64())
190 mtype = MPI_UINT64_T;
194 mtype = MPI_UINT32_T;
198 mtype = MPI_UINT16_T;
204 assert(
false &&
"unsupported type");
205 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
209 Value getMPIOp(
const Location loc, ConversionPatternRewriter &rewriter,
210 mpi::MPI_ReductionOpEnum opAttr)
override {
211 int32_t op = MPI_NO_OP;
213 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
216 case mpi::MPI_ReductionOpEnum::MPI_MAX:
219 case mpi::MPI_ReductionOpEnum::MPI_MIN:
222 case mpi::MPI_ReductionOpEnum::MPI_SUM:
225 case mpi::MPI_ReductionOpEnum::MPI_PROD:
228 case mpi::MPI_ReductionOpEnum::MPI_LAND:
231 case mpi::MPI_ReductionOpEnum::MPI_BAND:
234 case mpi::MPI_ReductionOpEnum::MPI_LOR:
237 case mpi::MPI_ReductionOpEnum::MPI_BOR:
240 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
243 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
246 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
249 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
252 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
256 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
263class OMPIImplTraits :
public MPIImplTraits {
264 LLVM::GlobalOp getOrDefineExternalStruct(
const Location loc,
265 ConversionPatternRewriter &rewriter,
267 LLVM::LLVMStructType type) {
269 return getOrDefineGlobal<LLVM::GlobalOp>(
270 getModuleOp(), loc, rewriter, name, type,
false,
271 LLVM::Linkage::External, name,
276 using MPIImplTraits::MPIImplTraits;
278 ~OMPIImplTraits()
override =
default;
280 Value getCommWorld(
const Location loc,
281 ConversionPatternRewriter &rewriter)
override {
282 auto *context = rewriter.getContext();
285 LLVM::LLVMStructType::getOpaque(
"ompi_communicator_t", context);
286 StringRef name =
"ompi_mpi_comm_world";
289 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
292 auto comm = LLVM::AddressOfOp::create(rewriter, loc,
293 LLVM::LLVMPointerType::get(context),
294 SymbolRefAttr::get(context, name));
295 return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
298 Value castComm(
const Location loc, ConversionPatternRewriter &rewriter,
299 Value comm)
override {
300 return LLVM::IntToPtrOp::create(
301 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
304 intptr_t getStatusIgnore()
override {
return 0; }
306 void *getInPlace()
override {
return reinterpret_cast<void *
>(1); }
308 Value getDataType(
const Location loc, ConversionPatternRewriter &rewriter,
309 Type type)
override {
312 mtype =
"ompi_mpi_float";
313 else if (type.
isF64())
314 mtype =
"ompi_mpi_double";
316 mtype =
"ompi_mpi_int64_t";
318 mtype =
"ompi_mpi_uint64_t";
320 mtype =
"ompi_mpi_int32_t";
322 mtype =
"ompi_mpi_uint32_t";
324 mtype =
"ompi_mpi_int16_t";
326 mtype =
"ompi_mpi_uint16_t";
328 mtype =
"ompi_mpi_int8_t";
330 mtype =
"ompi_mpi_uint8_t";
332 assert(
false &&
"unsupported type");
334 auto *context = rewriter.getContext();
337 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_datatype_t", context);
339 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
341 return LLVM::AddressOfOp::create(rewriter, loc,
342 LLVM::LLVMPointerType::get(context),
343 SymbolRefAttr::get(context, mtype));
346 Value getMPIOp(
const Location loc, ConversionPatternRewriter &rewriter,
347 mpi::MPI_ReductionOpEnum opAttr)
override {
350 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
351 op =
"ompi_mpi_no_op";
353 case mpi::MPI_ReductionOpEnum::MPI_MAX:
356 case mpi::MPI_ReductionOpEnum::MPI_MIN:
359 case mpi::MPI_ReductionOpEnum::MPI_SUM:
362 case mpi::MPI_ReductionOpEnum::MPI_PROD:
363 op =
"ompi_mpi_prod";
365 case mpi::MPI_ReductionOpEnum::MPI_LAND:
366 op =
"ompi_mpi_land";
368 case mpi::MPI_ReductionOpEnum::MPI_BAND:
369 op =
"ompi_mpi_band";
371 case mpi::MPI_ReductionOpEnum::MPI_LOR:
374 case mpi::MPI_ReductionOpEnum::MPI_BOR:
377 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
378 op =
"ompi_mpi_lxor";
380 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
381 op =
"ompi_mpi_bxor";
383 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
384 op =
"ompi_mpi_minloc";
386 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
387 op =
"ompi_mpi_maxloc";
389 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
390 op =
"ompi_mpi_replace";
393 auto *context = rewriter.getContext();
396 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_op_t", context);
398 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
400 return LLVM::AddressOfOp::create(rewriter, loc,
401 LLVM::LLVMPointerType::get(context),
402 SymbolRefAttr::get(context, op));
406std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
407 auto attr =
dlti::query(moduleOp, {
"MPI:Implementation"},
false);
409 return std::make_unique<MPICHImplTraits>(moduleOp);
410 auto strAttr = dyn_cast<StringAttr>(attr.value());
411 if (strAttr && strAttr.getValue() ==
"OpenMPI")
412 return std::make_unique<OMPIImplTraits>(moduleOp);
413 if (!strAttr || strAttr.getValue() !=
"MPICH")
414 moduleOp.emitWarning() <<
"Unknown \"MPI:Implementation\" value in DLTI ("
415 << (strAttr ? strAttr.getValue() :
"<NULL>")
416 <<
"), defaulting to MPICH";
417 return std::make_unique<MPICHImplTraits>(moduleOp);
424struct InitOpLowering :
public ConvertOpToLLVMPattern<mpi::InitOp> {
428 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
429 ConversionPatternRewriter &rewriter)
const override {
430 Location loc = op.getLoc();
433 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
436 auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
437 Value llvmnull = nullPtrOp.getRes();
440 auto moduleOp = op->getParentOfType<ModuleOp>();
444 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
446 LLVM::LLVMFuncOp initDecl =
450 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
461struct FinalizeOpLowering :
public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
465 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
466 ConversionPatternRewriter &rewriter)
const override {
468 Location loc = op.getLoc();
471 auto moduleOp = op->getParentOfType<ModuleOp>();
474 auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
477 moduleOp, loc, rewriter,
"MPI_Finalize", initFuncType);
480 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
ValueRange{});
490struct CommWorldOpLowering :
public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
494 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
495 ConversionPatternRewriter &rewriter)
const override {
497 auto moduleOp = op->getParentOfType<ModuleOp>();
498 auto mpiTraits = MPIImplTraits::get(moduleOp);
500 rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
510struct CommSplitOpLowering :
public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
514 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
515 ConversionPatternRewriter &rewriter)
const override {
517 auto moduleOp = op->getParentOfType<ModuleOp>();
518 auto mpiTraits = MPIImplTraits::get(moduleOp);
519 Type i32 = rewriter.getI32Type();
520 Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
521 Location loc = op.getLoc();
524 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
525 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
527 LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.
getType(), one);
531 LLVM::LLVMFunctionType::get(i32, {comm.
getType(), i32, i32, ptrType});
534 "MPI_Comm_split", funcType);
537 LLVM::CallOp::create(rewriter, loc, funcDecl,
539 adaptor.getKey(), outPtr.getRes()});
542 Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
543 res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
547 SmallVector<Value> replacements;
549 replacements.push_back(callOp.getResult());
552 replacements.push_back(res);
553 rewriter.replaceOp(op, replacements);
563struct CommRankOpLowering :
public ConvertOpToLLVMPattern<mpi::CommRankOp> {
567 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
568 ConversionPatternRewriter &rewriter)
const override {
570 Location loc = op.getLoc();
571 MLIRContext *context = rewriter.getContext();
572 Type i32 = rewriter.getI32Type();
575 Type ptrType = LLVM::LLVMPointerType::get(context);
578 auto moduleOp = op->getParentOfType<ModuleOp>();
580 auto mpiTraits = MPIImplTraits::get(moduleOp);
582 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
586 LLVM::LLVMFunctionType::get(i32, {comm.
getType(), ptrType});
589 moduleOp, loc, rewriter,
"MPI_Comm_rank", rankFuncType);
592 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
593 auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
594 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
599 LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
603 SmallVector<Value> replacements;
605 replacements.push_back(callOp.getResult());
608 replacements.push_back(loadedRank.getRes());
609 rewriter.replaceOp(op, replacements);
619static Value createOrFoldCommSize(ConversionPatternRewriter &rewriter,
620 Location loc, Value commOrg,
622 auto i32 = rewriter.getI32Type();
623 auto nRanksOp = mpi::CommSizeOp::create(rewriter, loc, i32, commOrg);
624 if (succeeded(
FoldToDLTIConst(nRanksOp,
"MPI:comm_world_size", rewriter)))
625 return nRanksOp.getSize();
626 rewriter.eraseOp(nRanksOp);
627 return mpi::CommSizeOp::create(rewriter, loc, i32, commAdapt).getSize();
630struct CommSizeOpLowering :
public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
634 matchAndRewrite(mpi::CommSizeOp op, OpAdaptor adaptor,
635 ConversionPatternRewriter &rewriter)
const override {
637 Location loc = op.getLoc();
638 MLIRContext *context = rewriter.getContext();
639 Type i32 = rewriter.getI32Type();
642 Type ptrType = LLVM::LLVMPointerType::get(context);
645 auto moduleOp = op->getParentOfType<ModuleOp>();
647 auto mpiTraits = MPIImplTraits::get(moduleOp);
649 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
653 LLVM::LLVMFunctionType::get(i32, {comm.
getType(), ptrType});
656 moduleOp, loc, rewriter,
"MPI_Comm_size", SizeFuncType);
659 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
660 auto sizeptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
661 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
666 LLVM::LoadOp::create(rewriter, loc, i32, sizeptr.getResult());
670 SmallVector<Value> replacements;
672 replacements.push_back(callOp.getResult());
675 replacements.push_back(loadedSize.getRes());
676 rewriter.replaceOp(op, replacements);
686struct SendOpLowering :
public ConvertOpToLLVMPattern<mpi::SendOp> {
690 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
691 ConversionPatternRewriter &rewriter)
const override {
693 Location loc = op.getLoc();
694 MLIRContext *context = rewriter.getContext();
695 Type i32 = rewriter.getI32Type();
696 Type elemType = op.getRef().getType().getElementType();
697 int64_t rank = op.getRef().getType().getRank();
700 Type ptrType = LLVM::LLVMPointerType::get(context);
703 auto moduleOp = op->getParentOfType<ModuleOp>();
706 auto [dataPtr, size] =
707 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
708 auto mpiTraits = MPIImplTraits::get(moduleOp);
709 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
710 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
714 auto funcType = LLVM::LLVMFunctionType::get(
715 i32, {ptrType, i32, dataType.
getType(), i32, i32, comm.
getType()});
717 LLVM::LLVMFuncOp funcDecl =
721 auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
724 adaptor.getTag(), comm});
726 rewriter.replaceOp(op, funcCall.getResult());
728 rewriter.eraseOp(op);
738struct RecvOpLowering :
public ConvertOpToLLVMPattern<mpi::RecvOp> {
742 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
743 ConversionPatternRewriter &rewriter)
const override {
745 Location loc = op.getLoc();
746 MLIRContext *context = rewriter.getContext();
747 Type i32 = rewriter.getI32Type();
748 Type i64 = rewriter.getI64Type();
749 Type elemType = op.getRef().getType().getElementType();
750 int64_t rank = op.getRef().getType().getRank();
753 Type ptrType = LLVM::LLVMPointerType::get(context);
756 auto moduleOp = op->getParentOfType<ModuleOp>();
759 auto [dataPtr, size] =
760 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
761 auto mpiTraits = MPIImplTraits::get(moduleOp);
762 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
763 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
764 Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
765 mpiTraits->getStatusIgnore());
767 LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
772 LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.
getType(), i32,
773 i32, comm.
getType(), ptrType});
775 LLVM::LLVMFuncOp funcDecl =
779 auto funcCall = LLVM::CallOp::create(
780 rewriter, loc, funcDecl,
781 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
782 adaptor.getTag(), comm, statusIgnore});
784 rewriter.replaceOp(op, funcCall.getResult());
786 rewriter.eraseOp(op);
796struct AllGatherOpLowering :
public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
800 matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
801 ConversionPatternRewriter &rewriter)
const override {
802 Location loc = op.getLoc();
803 MLIRContext *context = rewriter.getContext();
804 Type sElemType = op.getSendbuf().getType().getElementType();
805 Type rElemType = op.getRecvbuf().getType().getElementType();
806 int64_t sRank = op.getSendbuf().getType().getRank();
807 int64_t rRank = op.getRecvbuf().getType().getRank();
808 auto [sendPtr, sendSize] =
809 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, sElemType);
810 auto [recvPtr, recvSize] =
811 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, rElemType);
813 auto moduleOp = op->getParentOfType<ModuleOp>();
814 auto mpiTraits = MPIImplTraits::get(moduleOp);
815 Value sDataType = mpiTraits->getDataType(loc, rewriter, sElemType);
816 Value rDataType = mpiTraits->getDataType(loc, rewriter, rElemType);
817 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
819 Type ptrType = LLVM::LLVMPointerType::get(context);
820 Type i32 = rewriter.getI32Type();
825 auto funcType = LLVM::LLVMFunctionType::get(
826 i32, {ptrType, i32, sDataType.
getType(), ptrType, i32,
829 LLVM::LLVMFuncOp funcDecl =
834 createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
835 Value recvCountPerRank =
836 LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
840 LLVM::CallOp::create(rewriter, loc, funcDecl,
841 ValueRange{sendPtr, sendSize, sDataType, recvPtr,
842 recvCountPerRank, rDataType, comm});
845 rewriter.replaceOp(op, funcCall.getResult());
847 rewriter.eraseOp(op);
857struct AllReduceOpLowering :
public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
861 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
862 ConversionPatternRewriter &rewriter)
const override {
863 Location loc = op.getLoc();
864 MLIRContext *context = rewriter.getContext();
865 Type i32 = rewriter.getI32Type();
866 Type i64 = rewriter.getI64Type();
867 Type elemType = op.getSendbuf().getType().getElementType();
868 int64_t sRank = op.getSendbuf().getType().getRank();
869 int64_t rRank = op.getRecvbuf().getType().getRank();
872 Type ptrType = LLVM::LLVMPointerType::get(context);
873 auto moduleOp = op->getParentOfType<ModuleOp>();
874 auto mpiTraits = MPIImplTraits::get(moduleOp);
875 auto [sendPtr, sendSize] =
876 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
877 auto [recvPtr, recvSize] =
878 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
881 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
882 sendPtr = LLVM::ConstantOp::create(
884 reinterpret_cast<int64_t
>(mpiTraits->getInPlace()));
885 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
888 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
889 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
890 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
894 auto funcType = LLVM::LLVMFunctionType::get(
898 LLVM::LLVMFuncOp funcDecl =
902 auto funcCall = LLVM::CallOp::create(
903 rewriter, loc, funcDecl,
904 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
907 rewriter.replaceOp(op, funcCall.getResult());
909 rewriter.eraseOp(op);
919struct ReduceScatterBlockOpLowering
920 :
public ConvertOpToLLVMPattern<mpi::ReduceScatterBlockOp> {
924 matchAndRewrite(mpi::ReduceScatterBlockOp op, OpAdaptor adaptor,
925 ConversionPatternRewriter &rewriter)
const override {
926 Location loc = op.getLoc();
927 MLIRContext *context = rewriter.getContext();
928 Type i32 = rewriter.getI32Type();
929 Type i64 = rewriter.getI64Type();
930 Type elemType = op.getSendbuf().getType().getElementType();
931 int64_t sRank = op.getSendbuf().getType().getRank();
932 int64_t rRank = op.getRecvbuf().getType().getRank();
935 Type ptrType = LLVM::LLVMPointerType::get(context);
936 auto moduleOp = op->getParentOfType<ModuleOp>();
937 auto mpiTraits = MPIImplTraits::get(moduleOp);
938 auto [sendPtr, sendSize] =
939 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
940 auto [recvPtr, recvSize] =
941 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
944 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
945 sendPtr = LLVM::ConstantOp::create(
947 reinterpret_cast<int64_t
>(mpiTraits->getInPlace()));
948 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
951 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
952 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
953 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
956 createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
957 Value totalExpected =
958 LLVM::MulOp::create(rewriter, loc, i32, recvSize, nRanks);
959 Value sizeIsValid = LLVM::ICmpOp::create(
960 rewriter, loc, LLVM::ICmpPredicate::eq, sendSize, totalExpected);
961 cf::AssertOp::create(rewriter, loc, sizeIsValid,
962 "Send buffer's size must be the receive buffer's size "
963 "times the number of ranks");
967 auto funcType = LLVM::LLVMFunctionType::get(
972 moduleOp, loc, rewriter,
"MPI_Reduce_scatter_block", funcType);
975 auto funcCall = LLVM::CallOp::create(
976 rewriter, loc, funcDecl,
977 ValueRange{sendPtr, recvPtr, recvSize, dataType, mpiOp, comm});
980 rewriter.replaceOp(op, funcCall.getResult());
982 rewriter.eraseOp(op);
993struct FuncToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
997 void populateConvertToLLVMConversionPatterns(
998 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
999 RewritePatternSet &patterns)
const final {
1014 converter.addConversion([](mpi::CommType type) {
1015 return IntegerType::get(type.getContext(), 64);
1017 patterns.
add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
1018 CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
1019 SendOpLowering, RecvOpLowering, AllGatherOpLowering,
1020 AllReduceOpLowering, ReduceScatterBlockOpLowering>(converter);
1025 dialect->addInterfaces<FuncToLLVMDialectInterface>();
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This provides public APIs that all operations should have.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
LogicalResult FoldToDLTIConst(OpT op, const char *key, mlir::PatternRewriter &b)
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...