// Returns the `Op` registered under symbol `name` in `moduleOp`; when the
// lookup comes back empty, a new op is created at the start of the module body
// with `args` forwarded to `Op::create`.
// NOTE(review): this listing jumps from original line 32 to 35 and stops at
// 38, so the end of the parameter list, the declaration of `ret` and the tail
// of the function are not visible here.
30template <
typename Op,
typename... Args>
31static Op getOrDefineGlobal(ModuleOp &moduleOp,
const Location loc,
32 ConversionPatternRewriter &rewriter, StringRef name,
// Only create the op when no symbol with that name (and op kind) exists yet.
35 if (!(ret = moduleOp.lookupSymbol<
Op>(name))) {
// Scope the insertion-point change made on the next line to this branch.
36 ConversionPatternRewriter::InsertionGuard guard(rewriter);
37 rewriter.setInsertionPointToStart(moduleOp.getBody());
38 ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
// Thin wrapper over getOrDefineGlobal for LLVM::LLVMFuncOp: looks up, or
// creates with external linkage, a function symbol `name` of type `type`.
// NOTE(review): the head of this signature (original lines 44 and 46) is
// elided in this listing.
45 ConversionPatternRewriter &rewriter,
47 LLVM::LLVMFunctionType type) {
// `name` is passed twice: first as the symbol to look up, then as the first of
// the arguments that getOrDefineGlobal forwards to the op's `create`.
48 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
49 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
// Builds the (pointer, size) pair that the lowerings below pass to MPI calls,
// starting from `memRef`, which is treated as an LLVM struct value.
// NOTE(review): original lines 54, 56, 60, 62, 65 and 67-69 are elided in this
// listing, so some assignments and the else branch are not visible.
52std::pair<Value, Value> getRawPtrAndSize(
const Location loc,
53 ConversionPatternRewriter &rewriter,
55 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
// Field 1 of the struct is read as a pointer.
57 LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
// Field 2 of the struct is read as an i64 offset.
58 Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
59 rewriter.getI64Type(), memRef, 2);
// Advance the pointer by `offset` elements of `elType`.
61 LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
// Only structs with more than three fields take this branch.
63 if (cast<LLVM::LLVMStructType>(memRef.
getType()).getBody().size() > 3) {
64 size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
// The size is narrowed to i32 before it is returned.
66 size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
70 return {resPtr, size};
// Members of the MPIImplTraits interface. MPICHImplTraits and OMPIImplTraits
// below override the pure virtual functions.
// NOTE(review): the class header and closing brace are not visible in this
// listing.

// Factory returning the traits object for `moduleOp`; defined further down in
// this file.
84 static std::unique_ptr<MPIImplTraits>
get(ModuleOp &moduleOp);
// Stores the module passed in; derived classes inherit this constructor.
86 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
88 virtual ~MPIImplTraits() =
default;
// Accessor for the module given at construction.
90 ModuleOp &getModuleOp() {
return moduleOp; }
// Produces the value that CommWorldOpLowering substitutes for mpi::CommWorldOp.
96 virtual Value getCommWorld(
const Location loc,
97 ConversionPatternRewriter &rewriter) = 0;
// Converts a converted communicator value into the form passed to MPI calls.
101 virtual Value castComm(
const Location loc,
102 ConversionPatternRewriter &rewriter, Value comm) = 0;
// Integer that RecvOpLowering materializes as an i64 constant and converts to
// a pointer for the status argument.
105 virtual intptr_t getStatusIgnore() = 0;
// Address that AllReduceOpLowering uses as send buffer when send and receive
// buffers are the same value.
108 virtual void *getInPlace() = 0;
// Produces the datatype argument for MPI calls from an element `type`.
112 virtual Value getDataType(
const Location loc,
113 ConversionPatternRewriter &rewriter, Type type) = 0;
// Produces the reduction-operation argument for MPI calls from `opAttr`.
117 virtual Value getMPIOp(
const Location loc,
118 ConversionPatternRewriter &rewriter,
119 mpi::MPI_ReductionOpEnum opAttr) = 0;
// MPIImplTraits implementation in which communicators, datatypes and
// reduction ops are plain integer constants.
// NOTE(review): the numeric values presumably mirror MPICH's mpi.h; confirm
// against the MPICH version being targeted. Several original lines (access
// specifier, branch bodies, closing braces) are elided in this listing.
126class MPICHImplTraits :
public MPIImplTraits {
// Datatype handle values selected by getDataType.
127 static constexpr int MPI_FLOAT = 0x4c00040a;
128 static constexpr int MPI_DOUBLE = 0x4c00080b;
129 static constexpr int MPI_INT8_T = 0x4c000137;
130 static constexpr int MPI_INT16_T = 0x4c000238;
131 static constexpr int MPI_INT32_T = 0x4c000439;
132 static constexpr int MPI_INT64_T = 0x4c00083a;
133 static constexpr int MPI_UINT8_T = 0x4c00013b;
134 static constexpr int MPI_UINT16_T = 0x4c00023c;
135 static constexpr int MPI_UINT32_T = 0x4c00043d;
136 static constexpr int MPI_UINT64_T = 0x4c00083e;
// Reduction-operation handle values selected by getMPIOp.
137 static constexpr int MPI_MAX = 0x58000001;
138 static constexpr int MPI_MIN = 0x58000002;
139 static constexpr int MPI_SUM = 0x58000003;
140 static constexpr int MPI_PROD = 0x58000004;
141 static constexpr int MPI_LAND = 0x58000005;
142 static constexpr int MPI_BAND = 0x58000006;
143 static constexpr int MPI_LOR = 0x58000007;
144 static constexpr int MPI_BOR = 0x58000008;
145 static constexpr int MPI_LXOR = 0x58000009;
146 static constexpr int MPI_BXOR = 0x5800000a;
147 static constexpr int MPI_MINLOC = 0x5800000b;
148 static constexpr int MPI_MAXLOC = 0x5800000c;
149 static constexpr int MPI_REPLACE = 0x5800000d;
150 static constexpr int MPI_NO_OP = 0x5800000e;
// Reuse the base-class constructor taking the module.
153 using MPIImplTraits::MPIImplTraits;
155 ~MPICHImplTraits()
override =
default;
// Returns an i64 constant; the operand line is elided here (presumably the
// local MPI_COMM_WORLD value declared just below).
157 Value getCommWorld(
const Location loc,
158 ConversionPatternRewriter &rewriter)
override {
159 static constexpr int MPI_COMM_WORLD = 0x44000000;
160 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
// Narrows the communicator value to i32 for use in MPI calls.
164 Value castComm(
const Location loc, ConversionPatternRewriter &rewriter,
165 Value comm)
override {
166 return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
// Status-ignore value for this implementation is 1.
169 intptr_t getStatusIgnore()
override {
return 1; }
// In-place marker for this implementation is the address -1.
171 void *getInPlace()
override {
return reinterpret_cast<void *
>(-1); }
// Picks one of the datatype constants above from `type`, asserts when no
// branch matches, and returns an i32 constant.
// NOTE(review): most branch conditions and the constant's operand are elided
// in this listing.
173 Value getDataType(
const Location loc, ConversionPatternRewriter &rewriter,
174 Type type)
override {
178 else if (type.
isF64())
183 mtype = MPI_UINT64_T;
187 mtype = MPI_UINT32_T;
191 mtype = MPI_UINT16_T;
197 assert(
false &&
"unsupported type");
198 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
// Maps `opAttr` to one of the reduction constants above and returns it as an
// i32 constant; `op` starts out as MPI_NO_OP.
// NOTE(review): the switch header and the per-case assignments are elided in
// this listing.
202 Value getMPIOp(
const Location loc, ConversionPatternRewriter &rewriter,
203 mpi::MPI_ReductionOpEnum opAttr)
override {
204 int32_t op = MPI_NO_OP;
206 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
209 case mpi::MPI_ReductionOpEnum::MPI_MAX:
212 case mpi::MPI_ReductionOpEnum::MPI_MIN:
215 case mpi::MPI_ReductionOpEnum::MPI_SUM:
218 case mpi::MPI_ReductionOpEnum::MPI_PROD:
221 case mpi::MPI_ReductionOpEnum::MPI_LAND:
224 case mpi::MPI_ReductionOpEnum::MPI_BAND:
227 case mpi::MPI_ReductionOpEnum::MPI_LOR:
230 case mpi::MPI_ReductionOpEnum::MPI_BOR:
233 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
236 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
239 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
242 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
245 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
249 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
// MPIImplTraits implementation in which communicators, datatypes and
// reduction ops are addresses of external global symbols (names starting with
// "ompi_mpi_") rather than integer constants.
// NOTE(review): several original lines (access specifier, declarations,
// branch conditions, closing braces) are elided in this listing.
256class OMPIImplTraits :
public MPIImplTraits {
// Looks up, or creates with external linkage, a global named `name` of struct
// type `type` in the module held by this traits object.
// NOTE(review): the `false` argument is presumably the global's "constant"
// flag; the remaining arguments are elided here.
257 LLVM::GlobalOp getOrDefineExternalStruct(
const Location loc,
258 ConversionPatternRewriter &rewriter,
260 LLVM::LLVMStructType type) {
262 return getOrDefineGlobal<LLVM::GlobalOp>(
263 getModuleOp(), loc, rewriter, name, type,
false,
264 LLVM::Linkage::External, name,
// Reuse the base-class constructor taking the module.
269 using MPIImplTraits::MPIImplTraits;
271 ~OMPIImplTraits()
override =
default;
// Declares the external global "ompi_mpi_comm_world" with the opaque struct
// type "ompi_communicator_t", takes its address and returns that address
// converted to i64.
273 Value getCommWorld(
const Location loc,
274 ConversionPatternRewriter &rewriter)
override {
275 auto *context = rewriter.getContext();
278 LLVM::LLVMStructType::getOpaque(
"ompi_communicator_t", context);
279 StringRef name =
"ompi_mpi_comm_world";
282 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
285 auto comm = LLVM::AddressOfOp::create(rewriter, loc,
286 LLVM::LLVMPointerType::get(context),
287 SymbolRefAttr::get(context, name));
288 return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
// Converts the integer communicator value back into a pointer.
291 Value castComm(
const Location loc, ConversionPatternRewriter &rewriter,
292 Value comm)
override {
293 return LLVM::IntToPtrOp::create(
294 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
// Status-ignore value for this implementation is 0.
297 intptr_t getStatusIgnore()
override {
return 0; }
// In-place marker for this implementation is the address 1.
299 void *getInPlace()
override {
return reinterpret_cast<void *
>(1); }
// Picks the name of a datatype symbol from `type`, asserts when no branch
// matches, declares that symbol as an external global of the opaque struct
// type "ompi_predefined_datatype_t" and returns its address.
// NOTE(review): most branch conditions are elided in this listing.
301 Value getDataType(
const Location loc, ConversionPatternRewriter &rewriter,
302 Type type)
override {
305 mtype =
"ompi_mpi_float";
306 else if (type.
isF64())
307 mtype =
"ompi_mpi_double";
309 mtype =
"ompi_mpi_int64_t";
311 mtype =
"ompi_mpi_uint64_t";
313 mtype =
"ompi_mpi_int32_t";
315 mtype =
"ompi_mpi_uint32_t";
317 mtype =
"ompi_mpi_int16_t";
319 mtype =
"ompi_mpi_uint16_t";
321 mtype =
"ompi_mpi_int8_t";
323 mtype =
"ompi_mpi_uint8_t";
325 assert(
false &&
"unsupported type");
327 auto *context = rewriter.getContext();
330 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_datatype_t", context);
332 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
334 return LLVM::AddressOfOp::create(rewriter, loc,
335 LLVM::LLVMPointerType::get(context),
336 SymbolRefAttr::get(context, mtype));
// Maps `opAttr` to the name of a reduction-op symbol, declares that symbol as
// an external global of the opaque struct type "ompi_predefined_op_t" and
// returns its address. MPI_OP_NULL is mapped to "ompi_mpi_no_op".
// NOTE(review): the switch header and some per-case assignments are elided in
// this listing.
339 Value getMPIOp(
const Location loc, ConversionPatternRewriter &rewriter,
340 mpi::MPI_ReductionOpEnum opAttr)
override {
343 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
344 op =
"ompi_mpi_no_op";
346 case mpi::MPI_ReductionOpEnum::MPI_MAX:
349 case mpi::MPI_ReductionOpEnum::MPI_MIN:
352 case mpi::MPI_ReductionOpEnum::MPI_SUM:
355 case mpi::MPI_ReductionOpEnum::MPI_PROD:
356 op =
"ompi_mpi_prod";
358 case mpi::MPI_ReductionOpEnum::MPI_LAND:
359 op =
"ompi_mpi_land";
361 case mpi::MPI_ReductionOpEnum::MPI_BAND:
362 op =
"ompi_mpi_band";
364 case mpi::MPI_ReductionOpEnum::MPI_LOR:
367 case mpi::MPI_ReductionOpEnum::MPI_BOR:
370 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
371 op =
"ompi_mpi_lxor";
373 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
374 op =
"ompi_mpi_bxor";
376 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
377 op =
"ompi_mpi_minloc";
379 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
380 op =
"ompi_mpi_maxloc";
382 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
383 op =
"ompi_mpi_replace";
386 auto *context = rewriter.getContext();
389 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_op_t", context);
391 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
393 return LLVM::AddressOfOp::create(rewriter, loc,
394 LLVM::LLVMPointerType::get(context),
395 SymbolRefAttr::get(context, op));
// Chooses the traits implementation from the DLTI entry "MPI:Implementation"
// queried on `moduleOp`: the string "OpenMPI" selects OMPIImplTraits,
// everything else yields MPICHImplTraits. A value that is not a string, or is
// a string other than "MPICH", additionally emits a warning on the module.
399std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
// The trailing `true` is the query's emitError argument.
400 auto attr =
dlti::query(*&moduleOp, {
"MPI:Implementation"},
true);
// NOTE(review): original line 401 is elided in this listing; presumably it is
// the condition on `attr` that guards the return below.
402 return std::make_unique<MPICHImplTraits>(moduleOp);
403 auto strAttr = dyn_cast<StringAttr>(attr.value());
404 if (strAttr && strAttr.getValue() ==
"OpenMPI")
405 return std::make_unique<OMPIImplTraits>(moduleOp);
// Warn about unrecognized values, then fall back to MPICH.
406 if (!strAttr || strAttr.getValue() !=
"MPICH")
407 moduleOp.emitWarning() <<
"Unknown \"MPI:Implementation\" value in DLTI ("
408 << (strAttr ? strAttr.getValue() :
"<NULL>")
409 <<
"), defaulting to MPICH";
410 return std::make_unique<MPICHImplTraits>(moduleOp);
// Conversion pattern for mpi::InitOp: the op is replaced by an LLVM call to a
// function of type i32(ptr, ptr).
// NOTE(review): the lines that name the callee and list the call operands are
// elided in this listing; `llvmnull` is presumably what gets passed.
417struct InitOpLowering :
public ConvertOpToLLVMPattern<mpi::InitOp> {
421 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
422 ConversionPatternRewriter &rewriter)
const override {
423 Location loc = op.getLoc();
426 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
// Null pointer value built with LLVM::ZeroOp.
429 auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
430 Value llvmnull = nullPtrOp.getRes();
433 auto moduleOp = op->getParentOfType<ModuleOp>();
437 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
439 LLVM::LLVMFuncOp initDecl =
443 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
// Conversion pattern for mpi::FinalizeOp: the op is replaced by an LLVM call,
// without operands, to "MPI_Finalize" declared with type i32().
// NOTE(review): the line that starts the declaration lookup (original 469) is
// elided in this listing.
454struct FinalizeOpLowering :
public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
458 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
459 ConversionPatternRewriter &rewriter)
const override {
461 Location loc = op.getLoc();
464 auto moduleOp = op->getParentOfType<ModuleOp>();
// Despite the variable names, this type and declaration are for MPI_Finalize.
467 auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
470 moduleOp, loc, rewriter,
"MPI_Finalize", initFuncType);
473 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
ValueRange{});
// Conversion pattern for mpi::CommWorldOp: the op is replaced by whatever the
// module's MPIImplTraits returns from getCommWorld.
483struct CommWorldOpLowering :
public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
487 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
488 ConversionPatternRewriter &rewriter)
const override {
// The traits object is selected per enclosing module.
490 auto moduleOp = op->getParentOfType<ModuleOp>();
491 auto mpiTraits = MPIImplTraits::get(moduleOp);
493 rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
// Conversion pattern for mpi::CommSplitOp: calls "MPI_Comm_split" with an
// alloca'd output slot, loads the new communicator from that slot and
// replaces the op with it (preceded by the call result in some cases; the
// condition is elided in this listing).
// NOTE(review): the slot is allocated with element type `comm.getType()`, but
// the value is loaded as i32 and sign-extended to i64. Under
// OMPIImplTraits::castComm `comm` is a pointer, so confirm that an i32 load is
// intended for that implementation.
503struct CommSplitOpLowering :
public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
507 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
508 ConversionPatternRewriter &rewriter)
const override {
510 auto moduleOp = op->getParentOfType<ModuleOp>();
511 auto mpiTraits = MPIImplTraits::get(moduleOp);
512 Type i32 = rewriter.getI32Type();
513 Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
514 Location loc = op.getLoc();
// Communicator in the implementation-specific form expected by MPI calls.
517 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// Stack slot for one element of the communicator's type (the call's output).
518 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
520 LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.
getType(), one);
// Callee type: i32(comm, i32, i32, ptr).
524 LLVM::LLVMFunctionType::get(i32, {comm.
getType(), i32, i32, ptrType});
527 "MPI_Comm_split", funcType);
530 LLVM::CallOp::create(rewriter, loc, funcDecl,
532 adaptor.getKey(), outPtr.getRes()});
// Read the result back and widen it to i64.
535 Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
536 res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
540 SmallVector<Value> replacements;
542 replacements.push_back(callOp.getResult());
545 replacements.push_back(res);
546 rewriter.replaceOp(op, replacements);
// Conversion pattern for mpi::CommRankOp: calls "MPI_Comm_rank" with an
// alloca'd i32 output slot, loads the rank from that slot and replaces the op
// with it (preceded by the call result in some cases; the condition is elided
// in this listing).
556struct CommRankOpLowering :
public ConvertOpToLLVMPattern<mpi::CommRankOp> {
560 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
561 ConversionPatternRewriter &rewriter)
const override {
563 Location loc = op.getLoc();
564 MLIRContext *context = rewriter.getContext();
565 Type i32 = rewriter.getI32Type();
568 Type ptrType = LLVM::LLVMPointerType::get(context);
571 auto moduleOp = op->getParentOfType<ModuleOp>();
573 auto mpiTraits = MPIImplTraits::get(moduleOp);
// Communicator in the implementation-specific form expected by MPI calls.
575 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// Callee type: i32(comm, ptr).
579 LLVM::LLVMFunctionType::get(i32, {comm.
getType(), ptrType});
582 moduleOp, loc, rewriter,
"MPI_Comm_rank", rankFuncType);
// Stack slot for one i32 that receives the rank.
585 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
586 auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
587 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
// Read the rank back from the slot.
592 LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
596 SmallVector<Value> replacements;
598 replacements.push_back(callOp.getResult());
601 replacements.push_back(loadedRank.getRes());
602 rewriter.replaceOp(op, replacements);
// Conversion pattern for mpi::SendOp: derives pointer, size and datatype from
// the referenced buffer and emits a call of type
// i32(ptr, i32, <datatype>, i32, i32, <comm>).
// NOTE(review): the callee lookup, the first call operands and the condition
// that chooses between replaceOp and eraseOp are elided in this listing.
612struct SendOpLowering :
public ConvertOpToLLVMPattern<mpi::SendOp> {
616 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
617 ConversionPatternRewriter &rewriter)
const override {
619 Location loc = op.getLoc();
620 MLIRContext *context = rewriter.getContext();
621 Type i32 = rewriter.getI32Type();
622 Type elemType = op.getRef().getType().getElementType();
625 Type ptrType = LLVM::LLVMPointerType::get(context);
628 auto moduleOp = op->getParentOfType<ModuleOp>();
// Raw data pointer and i32 size taken from the converted buffer operand.
631 auto [dataPtr, size] =
632 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
// Implementation-specific datatype and communicator values.
633 auto mpiTraits = MPIImplTraits::get(moduleOp);
634 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
635 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// Datatype and communicator parameter types follow the traits' value types.
639 auto funcType = LLVM::LLVMFunctionType::get(
640 i32, {ptrType, i32, dataType.
getType(), i32, i32, comm.
getType()});
642 LLVM::LLVMFuncOp funcDecl =
646 auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
649 adaptor.getTag(), comm});
// Presumably only one of the two statements below runs, depending on a
// condition that is not visible here.
651 rewriter.replaceOp(op, funcCall.getResult());
653 rewriter.eraseOp(op);
// Conversion pattern for mpi::RecvOp: like the send lowering, plus a trailing
// status argument built from the traits' status-ignore value. The emitted call
// has type i32(ptr, i32, <datatype>, i32, i32, <comm>, ptr).
// NOTE(review): the callee lookup, the target of the IntToPtr result and the
// condition that chooses between replaceOp and eraseOp are elided in this
// listing.
663struct RecvOpLowering :
public ConvertOpToLLVMPattern<mpi::RecvOp> {
667 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
668 ConversionPatternRewriter &rewriter)
const override {
670 Location loc = op.getLoc();
671 MLIRContext *context = rewriter.getContext();
672 Type i32 = rewriter.getI32Type();
673 Type i64 = rewriter.getI64Type();
674 Type elemType = op.getRef().getType().getElementType();
677 Type ptrType = LLVM::LLVMPointerType::get(context);
680 auto moduleOp = op->getParentOfType<ModuleOp>();
// Raw data pointer and i32 size taken from the converted buffer operand.
683 auto [dataPtr, size] =
684 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
// Implementation-specific datatype and communicator values.
685 auto mpiTraits = MPIImplTraits::get(moduleOp);
686 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
687 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// Status-ignore value as an i64 constant, then converted to a pointer.
688 Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
689 mpiTraits->getStatusIgnore());
691 LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
696 LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.
getType(), i32,
697 i32, comm.
getType(), ptrType});
699 LLVM::LLVMFuncOp funcDecl =
703 auto funcCall = LLVM::CallOp::create(
704 rewriter, loc, funcDecl,
705 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
706 adaptor.getTag(), comm, statusIgnore});
// Presumably only one of the two statements below runs, depending on a
// condition that is not visible here.
708 rewriter.replaceOp(op, funcCall.getResult());
710 rewriter.eraseOp(op);
// Conversion pattern for mpi::AllReduceOp: derives pointers and sizes for the
// send and receive buffers and emits a call with the operands
// (sendPtr, recvPtr, sendSize, dataType, mpiOp, comm).
// NOTE(review): the function type's arguments, the callee lookup and the
// condition that chooses between replaceOp and eraseOp are elided in this
// listing.
720struct AllReduceOpLowering :
public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
724 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
725 ConversionPatternRewriter &rewriter)
const override {
726 Location loc = op.getLoc();
727 MLIRContext *context = rewriter.getContext();
728 Type i32 = rewriter.getI32Type();
729 Type i64 = rewriter.getI64Type();
// The element type is taken from the send buffer and used for both buffers.
730 Type elemType = op.getSendbuf().getType().getElementType();
733 Type ptrType = LLVM::LLVMPointerType::get(context);
734 auto moduleOp = op->getParentOfType<ModuleOp>();
735 auto mpiTraits = MPIImplTraits::get(moduleOp);
736 auto [sendPtr, sendSize] =
737 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
738 auto [recvPtr, recvSize] =
739 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
// When both operands are the same value, the send pointer is replaced by the
// traits' in-place marker converted to a pointer.
742 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
743 sendPtr = LLVM::ConstantOp::create(
745 reinterpret_cast<int64_t
>(mpiTraits->getInPlace()));
746 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
// Implementation-specific datatype, reduction op and communicator values.
749 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
750 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
751 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
755 auto funcType = LLVM::LLVMFunctionType::get(
759 LLVM::LLVMFuncOp funcDecl =
// Only the send-side size is passed to the call; recvSize is not among the
// operands.
763 auto funcCall = LLVM::CallOp::create(
764 rewriter, loc, funcDecl,
765 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
// Presumably only one of the two statements below runs, depending on a
// condition that is not visible here.
768 rewriter.replaceOp(op, funcCall.getResult());
770 rewriter.eraseOp(op);
// ConvertToLLVMPatternInterface implementation; an instance of it is attached
// to a dialect near the end of this file.
// NOTE(review): only the header of the overriding method is visible in this
// listing; its body (original lines 788 onward) is elided.
781struct FuncToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
785 void populateConvertToLLVMConversionPatterns(
786 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
787 RewritePatternSet &
patterns)
const final {
// NOTE(review): the enclosing function header is elided in this listing; the
// use of `converter` and `patterns` suggests these lines belong to
// populateMPIToLLVMConversionPatterns.
// mpi::CommType values are converted to 64-bit integers.
802 converter.addConversion([](mpi::CommType type) {
803 return IntegerType::get(type.getContext(), 64);
// Register every lowering pattern defined in this file.
805 patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
806 FinalizeOpLowering, InitOpLowering, SendOpLowering,
807 RecvOpLowering, AllReduceOpLowering>(converter);
// Attaches FuncToLLVMDialectInterface to `dialect`.
// NOTE(review): the enclosing function and the origin of `dialect` are elided
// in this listing.
812 dialect->addInterfaces<FuncToLLVMDialectInterface>();
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...