/// Looks up the symbol `name` of op type `Op` in `moduleOp`. If no such
/// symbol exists, a new `Op` is created at the start of the module body by
/// forwarding `args` to `Op::create`. The InsertionGuard is used so that the
/// rewriter's previous insertion point is restored once the new op is built.
30template <
typename Op,
typename... Args>
31static Op getOrDefineGlobal(ModuleOp &moduleOp,
const Location loc,
32 ConversionPatternRewriter &rewriter, StringRef name,
35 if (!(ret = moduleOp.lookupSymbol<
Op>(name))) {
36 ConversionPatternRewriter::InsertionGuard guard(rewriter);
37 rewriter.setInsertionPointToStart(moduleOp.getBody());
38 ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
// Gets or declares an LLVM function `name` with signature `type` and external
// linkage. `name` is passed twice: once as the lookup key of getOrDefineGlobal
// and once as the symbol name forwarded to LLVM::LLVMFuncOp::create.
45 ConversionPatternRewriter &rewriter,
47 LLVM::LLVMFunctionType type) {
48 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
49 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
/// Converts the LLVM struct value `memRef` into a {pointer, count} pair for an
/// MPI call.
/// - Field 1 of the struct is read as the data pointer and field 2 as an i64
///   offset.
/// - The returned pointer is the data pointer advanced by that offset (a GEP
///   with element type `elType`).
/// - If the struct has more than three fields, the count is extracted from
///   it and truncated to i32.
/// NOTE(review): the struct is presumably a lowered memref descriptor
/// (allocated ptr, aligned ptr, offset, sizes, strides) — confirm.
52std::pair<Value, Value> getRawPtrAndSize(
const Location loc,
53 ConversionPatternRewriter &rewriter,
55 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
57 LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
58 Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
59 rewriter.getI64Type(), memRef, 2);
61 LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
// A struct with more than three fields carries the count; it is truncated to
// i32 because the MPI function types built in this file take the count as i32.
// NOTE(review): counts that do not fit in i32 are silently truncated here.
63 if (cast<LLVM::LLVMStructType>(memRef.
getType()).getBody().size() > 3) {
64 size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
66 size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
70 return {resPtr, size};
/// Interface over the ABI differences between MPI implementations;
/// MPICHImplTraits and OMPIImplTraits implement it. `get` selects the
/// implementation for a module.
84 static std::unique_ptr<MPIImplTraits>
get(ModuleOp &moduleOp);
86 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
88 virtual ~MPIImplTraits() =
default;
90 ModuleOp &getModuleOp() {
return moduleOp; }
/// Returns a value representing the implementation's world communicator.
96 virtual Value getCommWorld(Location loc,
97 ConversionPatternRewriter &rewriter) = 0;
/// Converts a communicator operand into the type used by the
/// implementation's C API signatures.
101 virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter,
/// Integer that the Recv lowering converts to a pointer and passes as the
/// status argument (presumably the implementation's MPI_STATUS_IGNORE).
105 virtual intptr_t getStatusIgnore() = 0;
/// Sentinel pointer passed as the send buffer when an all-reduce uses the
/// same value for send and receive buffers.
108 virtual void *getInPlace() = 0;
/// Returns a value identifying the MPI datatype for an element type.
112 virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter,
/// Returns a value identifying the MPI reduction operation `opAttr`.
117 virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter,
118 mpi::MPI_ReductionOpEnum opAttr) = 0;
/// MPICH traits. Communicators, datatypes and reduction ops are integer
/// handles here, so each query materializes an LLVM integer constant.
/// NOTE(review): the hex values below are hard-coded MPICH handle values;
/// confirm them against the mpi.h of the targeted MPICH version.
125class MPICHImplTraits :
public MPIImplTraits {
126 static constexpr int MPI_FLOAT = 0x4c00040a;
127 static constexpr int MPI_DOUBLE = 0x4c00080b;
128 static constexpr int MPI_INT8_T = 0x4c000137;
129 static constexpr int MPI_INT16_T = 0x4c000238;
130 static constexpr int MPI_INT32_T = 0x4c000439;
131 static constexpr int MPI_INT64_T = 0x4c00083a;
132 static constexpr int MPI_UINT8_T = 0x4c00013b;
133 static constexpr int MPI_UINT16_T = 0x4c00023c;
134 static constexpr int MPI_UINT32_T = 0x4c00043d;
135 static constexpr int MPI_UINT64_T = 0x4c00083e;
136 static constexpr int MPI_MAX = 0x58000001;
137 static constexpr int MPI_MIN = 0x58000002;
138 static constexpr int MPI_SUM = 0x58000003;
139 static constexpr int MPI_PROD = 0x58000004;
140 static constexpr int MPI_LAND = 0x58000005;
141 static constexpr int MPI_BAND = 0x58000006;
142 static constexpr int MPI_LOR = 0x58000007;
143 static constexpr int MPI_BOR = 0x58000008;
144 static constexpr int MPI_LXOR = 0x58000009;
145 static constexpr int MPI_BXOR = 0x5800000a;
146 static constexpr int MPI_MINLOC = 0x5800000b;
147 static constexpr int MPI_MAXLOC = 0x5800000c;
148 static constexpr int MPI_REPLACE = 0x5800000d;
149 static constexpr int MPI_NO_OP = 0x5800000e;
152 using MPIImplTraits::MPIImplTraits;
154 ~MPICHImplTraits()
override =
default;
/// MPI_COMM_WORLD as an i64 constant (the type converter maps mpi::CommType
/// to i64).
156 Value getCommWorld(
const Location loc,
157 ConversionPatternRewriter &rewriter)
override {
158 static constexpr int MPI_COMM_WORLD = 0x44000000;
159 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
/// Truncates the i64 communicator to the i32 handle used in the MPICH call
/// signatures.
163 Value castComm(
const Location loc, ConversionPatternRewriter &rewriter,
164 Value comm)
override {
165 return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
168 intptr_t getStatusIgnore()
override {
return 1; }
170 void *getInPlace()
override {
return reinterpret_cast<void *
>(-1); }
/// Maps an element type to its MPICH datatype handle as an i32 constant.
/// NOTE(review): unsupported types only trip an assert, so in NDEBUG builds
/// they are not diagnosed.
172 Value getDataType(
const Location loc, ConversionPatternRewriter &rewriter,
173 Type type)
override {
177 else if (type.
isF64())
182 mtype = MPI_UINT64_T;
186 mtype = MPI_UINT32_T;
190 mtype = MPI_UINT16_T;
196 assert(
false &&
"unsupported type");
197 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
/// Maps the reduction enum to MPICH's op handle as an i32 constant; the
/// handle starts out as MPI_NO_OP before the switch.
201 Value getMPIOp(
const Location loc, ConversionPatternRewriter &rewriter,
202 mpi::MPI_ReductionOpEnum opAttr)
override {
203 int32_t op = MPI_NO_OP;
205 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
208 case mpi::MPI_ReductionOpEnum::MPI_MAX:
211 case mpi::MPI_ReductionOpEnum::MPI_MIN:
214 case mpi::MPI_ReductionOpEnum::MPI_SUM:
217 case mpi::MPI_ReductionOpEnum::MPI_PROD:
220 case mpi::MPI_ReductionOpEnum::MPI_LAND:
223 case mpi::MPI_ReductionOpEnum::MPI_BAND:
226 case mpi::MPI_ReductionOpEnum::MPI_LOR:
229 case mpi::MPI_ReductionOpEnum::MPI_BOR:
232 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
235 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
238 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
241 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
244 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
248 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
/// Open MPI traits. Predefined communicators, datatypes and reduction ops are
/// referenced through the addresses of external `ompi_mpi_*` globals, which
/// are declared on demand with opaque struct types.
255class OMPIImplTraits :
public MPIImplTraits {
/// Gets or declares the external global `name` with the opaque struct type
/// `type` in the module returned by getModuleOp().
256 LLVM::GlobalOp getOrDefineExternalStruct(
const Location loc,
257 ConversionPatternRewriter &rewriter,
259 LLVM::LLVMStructType type) {
261 return getOrDefineGlobal<LLVM::GlobalOp>(
262 getModuleOp(), loc, rewriter, name, type,
false,
263 LLVM::Linkage::External, name,
268 using MPIImplTraits::MPIImplTraits;
270 ~OMPIImplTraits()
override =
default;
/// Address of `ompi_mpi_comm_world`, converted to i64 so it matches the i64
/// communicator type used by the conversion.
272 Value getCommWorld(
const Location loc,
273 ConversionPatternRewriter &rewriter)
override {
274 auto *context = rewriter.getContext();
277 LLVM::LLVMStructType::getOpaque(
"ompi_communicator_t", context);
278 StringRef name =
"ompi_mpi_comm_world";
281 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
284 auto comm = LLVM::AddressOfOp::create(rewriter, loc,
285 LLVM::LLVMPointerType::get(context),
286 SymbolRefAttr::get(context, name));
287 return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
/// Converts the i64 communicator back to a pointer for the Open MPI C API.
290 Value castComm(
const Location loc, ConversionPatternRewriter &rewriter,
291 Value comm)
override {
292 return LLVM::IntToPtrOp::create(
293 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
296 intptr_t getStatusIgnore()
override {
return 0; }
298 void *getInPlace()
override {
return reinterpret_cast<void *
>(1); }
/// Returns the address of the `ompi_mpi_<type>` global matching the element
/// type.
/// NOTE(review): unsupported types only trip an assert, so in NDEBUG builds
/// they are not diagnosed.
300 Value getDataType(
const Location loc, ConversionPatternRewriter &rewriter,
301 Type type)
override {
304 mtype =
"ompi_mpi_float";
305 else if (type.
isF64())
306 mtype =
"ompi_mpi_double";
308 mtype =
"ompi_mpi_int64_t";
310 mtype =
"ompi_mpi_uint64_t";
312 mtype =
"ompi_mpi_int32_t";
314 mtype =
"ompi_mpi_uint32_t";
316 mtype =
"ompi_mpi_int16_t";
318 mtype =
"ompi_mpi_uint16_t";
320 mtype =
"ompi_mpi_int8_t";
322 mtype =
"ompi_mpi_uint8_t";
324 assert(
false &&
"unsupported type");
326 auto *context = rewriter.getContext();
329 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_datatype_t", context);
331 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
333 return LLVM::AddressOfOp::create(rewriter, loc,
334 LLVM::LLVMPointerType::get(context),
335 SymbolRefAttr::get(context, mtype));
/// Returns the address of the `ompi_mpi_<op>` global for `opAttr`;
/// MPI_OP_NULL maps to `ompi_mpi_no_op`.
338 Value getMPIOp(
const Location loc, ConversionPatternRewriter &rewriter,
339 mpi::MPI_ReductionOpEnum opAttr)
override {
342 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
343 op =
"ompi_mpi_no_op";
345 case mpi::MPI_ReductionOpEnum::MPI_MAX:
348 case mpi::MPI_ReductionOpEnum::MPI_MIN:
351 case mpi::MPI_ReductionOpEnum::MPI_SUM:
354 case mpi::MPI_ReductionOpEnum::MPI_PROD:
355 op =
"ompi_mpi_prod";
357 case mpi::MPI_ReductionOpEnum::MPI_LAND:
358 op =
"ompi_mpi_land";
360 case mpi::MPI_ReductionOpEnum::MPI_BAND:
361 op =
"ompi_mpi_band";
363 case mpi::MPI_ReductionOpEnum::MPI_LOR:
366 case mpi::MPI_ReductionOpEnum::MPI_BOR:
369 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
370 op =
"ompi_mpi_lxor";
372 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
373 op =
"ompi_mpi_bxor";
375 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
376 op =
"ompi_mpi_minloc";
378 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
379 op =
"ompi_mpi_maxloc";
381 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
382 op =
"ompi_mpi_replace";
385 auto *context = rewriter.getContext();
388 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_op_t", context);
390 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
392 return LLVM::AddressOfOp::create(rewriter, loc,
393 LLVM::LLVMPointerType::get(context),
394 SymbolRefAttr::get(context, op));
/// Selects the MPI implementation traits from the DLTI entry
/// "MPI:Implementation" queried on `moduleOp`:
/// - "OpenMPI" selects OMPIImplTraits.
/// - "MPICH" selects MPICHImplTraits.
/// - Any other value, or a non-string attribute, emits a warning and falls
///   back to MPICHImplTraits.
/// The first MPICH return is presumably taken when the query fails — confirm.
/// NOTE(review): the query passes emitError = true (the default is false). If
/// a missing entry is meant to silently default to MPICH, emitting a
/// diagnostic on query failure may be unwanted — confirm.
398std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
399 auto attr =
dlti::query(*&moduleOp, {
"MPI:Implementation"},
true);
401 return std::make_unique<MPICHImplTraits>(moduleOp);
402 auto strAttr = dyn_cast<StringAttr>(attr.value());
403 if (strAttr && strAttr.getValue() ==
"OpenMPI")
404 return std::make_unique<OMPIImplTraits>(moduleOp);
405 if (!strAttr || strAttr.getValue() !=
"MPICH")
406 moduleOp.emitWarning() <<
"Unknown \"MPI:Implementation\" value in DLTI ("
407 << (strAttr ? strAttr.getValue() :
"<NULL>")
408 <<
"), defaulting to MPICH";
409 return std::make_unique<MPICHImplTraits>(moduleOp);
/// Lowers mpi::InitOp as follows:
/// - Materializes a null pointer (LLVM::ZeroOp).
/// - Gets or declares an i32(ptr, ptr) function in the enclosing module.
/// - Replaces the op with an LLVM call to that function.
/// Presumably the null pointer is passed for both arguments — confirm.
416struct InitOpLowering :
public ConvertOpToLLVMPattern<mpi::InitOp> {
420 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
421 ConversionPatternRewriter &rewriter)
const override {
422 Location loc = op.getLoc();
425 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
428 auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
429 Value llvmnull = nullPtrOp.getRes();
432 auto moduleOp = op->getParentOfType<ModuleOp>();
436 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
438 LLVM::LLVMFuncOp initDecl =
442 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
/// Lowers mpi::FinalizeOp as follows:
/// - Gets or declares `MPI_Finalize` with type i32() in the enclosing module.
/// - Replaces the op with an argument-less call to it.
/// (The local names reuse "init" even though this declares MPI_Finalize.)
453struct FinalizeOpLowering :
public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
457 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
458 ConversionPatternRewriter &rewriter)
const override {
460 Location loc = op.getLoc();
463 auto moduleOp = op->getParentOfType<ModuleOp>();
466 auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
469 moduleOp, loc, rewriter,
"MPI_Finalize", initFuncType);
472 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
ValueRange{});
/// Lowers mpi::CommWorldOp to the implementation-specific world-communicator
/// value returned by MPIImplTraits::getCommWorld.
482struct CommWorldOpLowering :
public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
486 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
487 ConversionPatternRewriter &rewriter)
const override {
489 auto moduleOp = op->getParentOfType<ModuleOp>();
490 auto mpiTraits = MPIImplTraits::get(moduleOp);
492 rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
/// Lowers mpi::CommSplitOp to a call of `MPI_Comm_split` with type
/// i32(comm, i32, i32, ptr). The new communicator is written to a stack slot
/// and then read back. The replacement values are the call's i32 result
/// (presumably only when the op has an error-code result) followed by the
/// new communicator widened to i64.
502struct CommSplitOpLowering :
public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
506 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
507 ConversionPatternRewriter &rewriter)
const override {
509 auto moduleOp = op->getParentOfType<ModuleOp>();
510 auto mpiTraits = MPIImplTraits::get(moduleOp);
511 Type i32 = rewriter.getI32Type();
512 Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
513 Location loc = op.getLoc();
516 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
517 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
519 LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.
getType(), one);
523 LLVM::LLVMFunctionType::get(i32, {comm.
getType(), i32, i32, ptrType});
526 "MPI_Comm_split", funcType);
529 LLVM::CallOp::create(rewriter, loc, funcDecl,
531 adaptor.getKey(), outPtr.getRes()});
// NOTE(review): the output slot above is allocated with comm.getType(), which
// is a pointer under OMPIImplTraits::castComm, but it is read back here as i32
// and sign-extended to i64. That round-trips MPICH's i32 handles. For Open MPI
// it would drop the upper 32 bits of the new communicator pointer. Consider
// loading comm.getType() and converting it back to i64 per implementation.
534 Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
535 res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
539 SmallVector<Value> replacements;
541 replacements.push_back(callOp.getResult());
544 replacements.push_back(res);
545 rewriter.replaceOp(op, replacements);
/// Lowers mpi::CommRankOp to a call of `MPI_Comm_rank` with type
/// i32(comm, ptr):
/// - The rank is written into a one-element i32 stack slot and loaded back.
/// - The replacement values are the call's i32 result (presumably only when
///   the op has an error-code result) followed by the loaded rank.
555struct CommRankOpLowering :
public ConvertOpToLLVMPattern<mpi::CommRankOp> {
559 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
560 ConversionPatternRewriter &rewriter)
const override {
562 Location loc = op.getLoc();
563 MLIRContext *context = rewriter.getContext();
564 Type i32 = rewriter.getI32Type();
567 Type ptrType = LLVM::LLVMPointerType::get(context);
570 auto moduleOp = op->getParentOfType<ModuleOp>();
572 auto mpiTraits = MPIImplTraits::get(moduleOp);
574 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
578 LLVM::LLVMFunctionType::get(i32, {comm.
getType(), ptrType});
581 moduleOp, loc, rewriter,
"MPI_Comm_rank", rankFuncType);
584 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
585 auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
586 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
591 LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
595 SmallVector<Value> replacements;
597 replacements.push_back(callOp.getResult());
600 replacements.push_back(loadedRank.getRes());
601 rewriter.replaceOp(op, replacements);
/// Lowers mpi::SendOp to a call of the MPI send function, with type
/// i32(ptr, i32, datatype, i32, i32, comm):
/// - The buffer pointer and i32 count come from getRawPtrAndSize.
/// - The datatype and communicator come from the implementation traits.
/// - The call's i32 result replaces the op; in the other path the op is
///   erased (presumably when the op has no error-code result — confirm).
611struct SendOpLowering :
public ConvertOpToLLVMPattern<mpi::SendOp> {
615 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
616 ConversionPatternRewriter &rewriter)
const override {
618 Location loc = op.getLoc();
619 MLIRContext *context = rewriter.getContext();
620 Type i32 = rewriter.getI32Type();
621 Type elemType = op.getRef().getType().getElementType();
624 Type ptrType = LLVM::LLVMPointerType::get(context);
627 auto moduleOp = op->getParentOfType<ModuleOp>();
630 auto [dataPtr, size] =
631 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
632 auto mpiTraits = MPIImplTraits::get(moduleOp);
633 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
634 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
638 auto funcType = LLVM::LLVMFunctionType::get(
639 i32, {ptrType, i32, dataType.
getType(), i32, i32, comm.
getType()});
641 LLVM::LLVMFuncOp funcDecl =
645 auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
648 adaptor.getTag(), comm});
650 rewriter.replaceOp(op, funcCall.getResult());
652 rewriter.eraseOp(op);
/// Lowers mpi::RecvOp to a call of the MPI receive function, with type
/// i32(ptr, i32, datatype, i32, i32, comm, ptr):
/// - Arguments are the buffer pointer and count, datatype, source, tag and
///   communicator.
/// - The last argument is the traits' getStatusIgnore() integer, converted to
///   a pointer.
/// - The call's i32 result replaces the op; in the other path the op is
///   erased (presumably when the op has no error-code result — confirm).
662struct RecvOpLowering :
public ConvertOpToLLVMPattern<mpi::RecvOp> {
666 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter)
const override {
669 Location loc = op.getLoc();
670 MLIRContext *context = rewriter.getContext();
671 Type i32 = rewriter.getI32Type();
672 Type i64 = rewriter.getI64Type();
673 Type elemType = op.getRef().getType().getElementType();
676 Type ptrType = LLVM::LLVMPointerType::get(context);
679 auto moduleOp = op->getParentOfType<ModuleOp>();
682 auto [dataPtr, size] =
683 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
684 auto mpiTraits = MPIImplTraits::get(moduleOp);
685 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
686 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
// The status-ignore sentinel is an integer in the traits; it is materialized
// as an i64 constant and converted to a pointer for the status argument.
687 Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
688 mpiTraits->getStatusIgnore());
690 LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
695 LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.
getType(), i32,
696 i32, comm.
getType(), ptrType});
698 LLVM::LLVMFuncOp funcDecl =
702 auto funcCall = LLVM::CallOp::create(
703 rewriter, loc, funcDecl,
704 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
705 adaptor.getTag(), comm, statusIgnore});
707 rewriter.replaceOp(op, funcCall.getResult());
709 rewriter.eraseOp(op);
/// Lowers mpi::AllReduceOp to a call of the MPI all-reduce function with
/// arguments (sendPtr, recvPtr, sendSize, datatype, op, comm):
/// - The send and receive pointers and sizes come from getRawPtrAndSize.
/// - The reduction op handle comes from MPIImplTraits::getMPIOp.
/// - The call's i32 result replaces the op; in the other path the op is
///   erased (presumably when the op has no error-code result — confirm).
719struct AllReduceOpLowering :
public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
723 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
724 ConversionPatternRewriter &rewriter)
const override {
725 Location loc = op.getLoc();
726 MLIRContext *context = rewriter.getContext();
727 Type i32 = rewriter.getI32Type();
728 Type i64 = rewriter.getI64Type();
729 Type elemType = op.getSendbuf().getType().getElementType();
732 Type ptrType = LLVM::LLVMPointerType::get(context);
733 auto moduleOp = op->getParentOfType<ModuleOp>();
734 auto mpiTraits = MPIImplTraits::get(moduleOp);
735 auto [sendPtr, sendSize] =
736 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
737 auto [recvPtr, recvSize] =
738 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
// When send and receive buffers are the same SSA value, the send pointer is
// replaced by the implementation's in-place sentinel (getInPlace), built as an
// integer constant and converted to a pointer.
741 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
742 sendPtr = LLVM::ConstantOp::create(
744 reinterpret_cast<int64_t
>(mpiTraits->getInPlace()));
745 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
748 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
749 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
// Despite its name, `commWorld` holds the op's (cast) communicator operand,
// which need not be the world communicator.
750 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
754 auto funcType = LLVM::LLVMFunctionType::get(
758 LLVM::LLVMFuncOp funcDecl =
762 auto funcCall = LLVM::CallOp::create(
763 rewriter, loc, funcDecl,
764 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
767 rewriter.replaceOp(op, funcCall.getResult());
769 rewriter.eraseOp(op);
/// ConvertToLLVMPatternInterface hook that contributes the MPI-to-LLVM
/// patterns to the generic convert-to-llvm pass.
/// NOTE(review): the "Func" prefix looks copied from the func dialect's
/// interface; a name such as MPIToLLVMDialectInterface would be clearer.
780struct FuncToLLVMDialectInterface :
public ConvertToLLVMPatternInterface {
784 void populateConvertToLLVMConversionPatterns(
785 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
786 RewritePatternSet &
patterns)
const final {
// mpi::CommType values are carried as i64 after conversion; the traits'
// castComm converts them to the type each implementation's C API expects.
801 converter.addConversion([](mpi::CommType type) {
802 return IntegerType::get(type.getContext(), 64);
// Registers every MPI op lowering defined above with the converter.
804 patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
805 FinalizeOpLowering, InitOpLowering, SendOpLowering,
806 RecvOpLowering, AllReduceOpLowering>(converter);
811 dialect->addInterfaces<FuncToLLVMDialectInterface>();
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...