30 template <
typename Op,
typename... Args>
31 static Op getOrDefineGlobal(ModuleOp &moduleOp,
const Location loc,
35 if (!(ret = moduleOp.lookupSymbol<
Op>(name))) {
36 ConversionPatternRewriter::InsertionGuard guard(rewriter);
38 ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
47 LLVM::LLVMFunctionType type) {
48 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
49 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
52 std::pair<Value, Value> getRawPtrAndSize(
const Location loc,
57 LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
58 Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
61 LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
63 if (cast<LLVM::LLVMStructType>(memRef.
getType()).getBody().size() > 3) {
64 size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
66 size = LLVM::TruncOp::create(rewriter, loc, rewriter.
getI32Type(), size);
70 return {resPtr, size};
// Factory: returns an owning pointer to the traits object to use for
// `moduleOp`.
// NOTE(review): the definition further down in this listing appears to query
// the DLTI key "MPI:Implementation" and to return OMPIImplTraits for
// "OpenMPI" and MPICHImplTraits otherwise — confirm against the full file.
84 static std::unique_ptr<MPIImplTraits>
get(ModuleOp &moduleOp);
// Records the module this traits object operates on. `explicit` prevents an
// implicit ModuleOp -> MPIImplTraits conversion.
86 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
// Virtual so that a derived traits object can be destroyed through the
// std::unique_ptr<MPIImplTraits> handed out by get().
88 virtual ~MPIImplTraits() =
default;
// Returns the module that was supplied to the constructor.
90 ModuleOp &getModuleOp() {
return moduleOp; }
// Implementation-specific integer value. The RecvOp lowering materialises it
// as an i64 constant, converts it to a pointer and passes it as the trailing
// status argument of the call it emits.
// NOTE(review): presumably the numeric value of MPI_STATUS_IGNORE — confirm.
105 virtual intptr_t getStatusIgnore() = 0;
// Implementation-specific sentinel pointer. The AllReduceOp lowering uses it
// in place of the send pointer when sendbuf and recvbuf are the same value.
// NOTE(review): presumably the value of MPI_IN_PLACE — confirm.
108 virtual void *getInPlace() = 0;
119 mpi::MPI_ReductionOpEnum opAttr) = 0;
126 class MPICHImplTraits :
public MPIImplTraits {
// Integer handle values used by the MPICH flavour of the traits. The
// type-mapping and op-mapping overrides later in this class select one of
// them and emit it as an i32 LLVM constant.
// NOTE(review): presumably copied from MPICH's mpi.h — TODO confirm the
// values against the targeted MPICH release.
//
// Datatype handles.
127 static constexpr
int MPI_FLOAT = 0x4c00040a;
128 static constexpr
int MPI_DOUBLE = 0x4c00080b;
129 static constexpr
int MPI_INT8_T = 0x4c000137;
130 static constexpr
int MPI_INT16_T = 0x4c000238;
131 static constexpr
int MPI_INT32_T = 0x4c000439;
132 static constexpr
int MPI_INT64_T = 0x4c00083a;
133 static constexpr
int MPI_UINT8_T = 0x4c00013b;
134 static constexpr
int MPI_UINT16_T = 0x4c00023c;
135 static constexpr
int MPI_UINT32_T = 0x4c00043d;
136 static constexpr
int MPI_UINT64_T = 0x4c00083e;
// Reduction-operation handles. MPI_NO_OP is the initial value of `op` in the
// op-mapping override below.
137 static constexpr
int MPI_MAX = 0x58000001;
138 static constexpr
int MPI_MIN = 0x58000002;
139 static constexpr
int MPI_SUM = 0x58000003;
140 static constexpr
int MPI_PROD = 0x58000004;
141 static constexpr
int MPI_LAND = 0x58000005;
142 static constexpr
int MPI_BAND = 0x58000006;
143 static constexpr
int MPI_LOR = 0x58000007;
144 static constexpr
int MPI_BOR = 0x58000008;
145 static constexpr
int MPI_LXOR = 0x58000009;
146 static constexpr
int MPI_BXOR = 0x5800000a;
147 static constexpr
int MPI_MINLOC = 0x5800000b;
148 static constexpr
int MPI_MAXLOC = 0x5800000c;
149 static constexpr
int MPI_REPLACE = 0x5800000d;
150 static constexpr
int MPI_NO_OP = 0x5800000e;
// Inherit the base-class constructor taking the ModuleOp reference.
153 using MPIImplTraits::MPIImplTraits;
// Nothing extra to release; `override` ties it to the virtual base destructor.
155 ~MPICHImplTraits()
override =
default;
159 static constexpr
int MPI_COMM_WORLD = 0x44000000;
160 return LLVM::ConstantOp::create(rewriter, loc, rewriter.
getI64Type(),
165 Value comm)
override {
166 return LLVM::TruncOp::create(rewriter, loc, rewriter.
getI32Type(), comm);
// MPICH flavour: the status-ignore value is the integer 1. The RecvOp
// lowering converts it to a pointer and passes it as the status argument.
// NOTE(review): presumably MPICH's MPI_STATUS_IGNORE — confirm against mpi.h.
169 intptr_t getStatusIgnore()
override {
return 1; }
// MPICH flavour: the in-place sentinel is -1 reinterpreted as a pointer. The
// AllReduceOp lowering uses it as the send pointer when sendbuf == recvbuf.
// NOTE(review): presumably MPICH's MPI_IN_PLACE — confirm against mpi.h.
171 void *getInPlace()
override {
return reinterpret_cast<void *
>(-1); }
174 Type type)
override {
178 else if (type.
isF64())
183 mtype = MPI_UINT64_T;
187 mtype = MPI_UINT32_T;
191 mtype = MPI_UINT16_T;
197 assert(
false &&
"unsupported type");
198 return LLVM::ConstantOp::create(rewriter, loc, rewriter.
getI32Type(),
203 mpi::MPI_ReductionOpEnum opAttr)
override {
204 int32_t op = MPI_NO_OP;
206 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
209 case mpi::MPI_ReductionOpEnum::MPI_MAX:
212 case mpi::MPI_ReductionOpEnum::MPI_MIN:
215 case mpi::MPI_ReductionOpEnum::MPI_SUM:
218 case mpi::MPI_ReductionOpEnum::MPI_PROD:
221 case mpi::MPI_ReductionOpEnum::MPI_LAND:
224 case mpi::MPI_ReductionOpEnum::MPI_BAND:
227 case mpi::MPI_ReductionOpEnum::MPI_LOR:
230 case mpi::MPI_ReductionOpEnum::MPI_BOR:
233 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
236 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
239 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
242 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
245 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
249 return LLVM::ConstantOp::create(rewriter, loc, rewriter.
getI32Type(), op);
256 class OMPIImplTraits :
public MPIImplTraits {
257 LLVM::GlobalOp getOrDefineExternalStruct(
const Location loc,
260 LLVM::LLVMStructType type) {
262 return getOrDefineGlobal<LLVM::GlobalOp>(
263 getModuleOp(), loc, rewriter, name, type,
false,
264 LLVM::Linkage::External, name,
// Inherit the base-class constructor taking the ModuleOp reference.
269 using MPIImplTraits::MPIImplTraits;
// Nothing extra to release; `override` ties it to the virtual base destructor.
271 ~OMPIImplTraits()
override =
default;
278 LLVM::LLVMStructType::getOpaque(
"ompi_communicator_t", context);
279 StringRef name =
"ompi_mpi_comm_world";
282 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
285 auto comm = LLVM::AddressOfOp::create(rewriter, loc,
288 return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.
getI64Type(), comm);
292 Value comm)
override {
293 return LLVM::IntToPtrOp::create(
// Open MPI flavour: the status-ignore value is 0, so the pointer the RecvOp
// lowering builds from it is a null pointer.
// NOTE(review): presumably Open MPI's MPI_STATUS_IGNORE — confirm against mpi.h.
297 intptr_t getStatusIgnore()
override {
return 0; }
// Open MPI flavour: the in-place sentinel is 1 reinterpreted as a pointer. The
// AllReduceOp lowering uses it as the send pointer when sendbuf == recvbuf.
// NOTE(review): presumably Open MPI's MPI_IN_PLACE — confirm against mpi.h.
299 void *getInPlace()
override {
return reinterpret_cast<void *
>(1); }
302 Type type)
override {
305 mtype =
"ompi_mpi_float";
306 else if (type.
isF64())
307 mtype =
"ompi_mpi_double";
309 mtype =
"ompi_mpi_int64_t";
311 mtype =
"ompi_mpi_uint64_t";
313 mtype =
"ompi_mpi_int32_t";
315 mtype =
"ompi_mpi_uint32_t";
317 mtype =
"ompi_mpi_int16_t";
319 mtype =
"ompi_mpi_uint16_t";
321 mtype =
"ompi_mpi_int8_t";
323 mtype =
"ompi_mpi_uint8_t";
325 assert(
false &&
"unsupported type");
330 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_datatype_t", context);
332 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
334 return LLVM::AddressOfOp::create(rewriter, loc,
340 mpi::MPI_ReductionOpEnum opAttr)
override {
343 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
344 op =
"ompi_mpi_no_op";
346 case mpi::MPI_ReductionOpEnum::MPI_MAX:
349 case mpi::MPI_ReductionOpEnum::MPI_MIN:
352 case mpi::MPI_ReductionOpEnum::MPI_SUM:
355 case mpi::MPI_ReductionOpEnum::MPI_PROD:
356 op =
"ompi_mpi_prod";
358 case mpi::MPI_ReductionOpEnum::MPI_LAND:
359 op =
"ompi_mpi_land";
361 case mpi::MPI_ReductionOpEnum::MPI_BAND:
362 op =
"ompi_mpi_band";
364 case mpi::MPI_ReductionOpEnum::MPI_LOR:
367 case mpi::MPI_ReductionOpEnum::MPI_BOR:
370 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
371 op =
"ompi_mpi_lxor";
373 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
374 op =
"ompi_mpi_bxor";
376 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
377 op =
"ompi_mpi_minloc";
379 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
380 op =
"ompi_mpi_maxloc";
382 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
383 op =
"ompi_mpi_replace";
389 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_op_t", context);
391 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
393 return LLVM::AddressOfOp::create(rewriter, loc,
400 auto attr =
dlti::query(*&moduleOp, {
"MPI:Implementation"},
true);
402 return std::make_unique<MPICHImplTraits>(moduleOp);
403 auto strAttr = dyn_cast<StringAttr>(attr.value());
404 if (strAttr && strAttr.getValue() ==
"OpenMPI")
405 return std::make_unique<OMPIImplTraits>(moduleOp);
406 if (!strAttr || strAttr.getValue() !=
"MPICH")
407 moduleOp.emitWarning() <<
"Unknown \"MPI:Implementation\" value in DLTI ("
408 << strAttr.getValue() <<
"), defaulting to MPICH";
409 return std::make_unique<MPICHImplTraits>(moduleOp);
420 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
428 auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
429 Value llvmnull = nullPtrOp.getRes();
432 auto moduleOp = op->getParentOfType<ModuleOp>();
438 LLVM::LLVMFuncOp initDecl =
457 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
463 auto moduleOp = op->getParentOfType<ModuleOp>();
469 moduleOp, loc, rewriter,
"MPI_Finalize", initFuncType);
486 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
489 auto moduleOp = op->getParentOfType<ModuleOp>();
492 rewriter.
replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
506 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
509 auto moduleOp = op->getParentOfType<ModuleOp>();
516 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
517 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
519 LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.
getType(), one);
526 "MPI_Comm_split", funcType);
529 LLVM::CallOp::create(rewriter, loc, funcDecl,
531 adaptor.getKey(), outPtr.getRes()});
534 Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
535 res = LLVM::SExtOp::create(rewriter, loc, rewriter.
getI64Type(), res);
541 replacements.push_back(callOp.getResult());
544 replacements.push_back(res);
559 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
570 auto moduleOp = op->getParentOfType<ModuleOp>();
574 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
581 moduleOp, loc, rewriter,
"MPI_Comm_rank", rankFuncType);
584 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
585 auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
586 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
591 LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
597 replacements.push_back(callOp.getResult());
600 replacements.push_back(loadedRank.getRes());
615 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
621 Type elemType = op.getRef().getType().getElementType();
627 auto moduleOp = op->getParentOfType<ModuleOp>();
630 auto [dataPtr, size] =
631 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
633 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
634 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
639 i32, {ptrType, i32, dataType.
getType(), i32, i32, comm.
getType()});
641 LLVM::LLVMFuncOp funcDecl =
645 auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
648 adaptor.getTag(), comm});
650 rewriter.
replaceOp(op, funcCall.getResult());
666 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
673 Type elemType = op.getRef().getType().getElementType();
679 auto moduleOp = op->getParentOfType<ModuleOp>();
682 auto [dataPtr, size] =
683 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
685 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
686 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
687 Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
688 mpiTraits->getStatusIgnore());
690 LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
696 i32, comm.
getType(), ptrType});
698 LLVM::LLVMFuncOp funcDecl =
702 auto funcCall = LLVM::CallOp::create(
703 rewriter, loc, funcDecl,
704 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
705 adaptor.getTag(), comm, statusIgnore});
707 rewriter.
replaceOp(op, funcCall.getResult());
723 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
729 Type elemType = op.getSendbuf().getType().getElementType();
733 auto moduleOp = op->getParentOfType<ModuleOp>();
735 auto [sendPtr, sendSize] =
736 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
737 auto [recvPtr, recvSize] =
738 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
741 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
742 sendPtr = LLVM::ConstantOp::create(
744 reinterpret_cast<int64_t
>(mpiTraits->getInPlace()));
745 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
748 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
749 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
750 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
758 LLVM::LLVMFuncOp funcDecl =
762 auto funcCall = LLVM::CallOp::create(
763 rewriter, loc, funcDecl,
764 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
767 rewriter.
replaceOp(op, funcCall.getResult());
784 void populateConvertToLLVMConversionPatterns(
804 patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
805 FinalizeOpLowering, InitOpLowering, SendOpLowering,
806 RecvOpLowering, AllReduceOpLowering>(converter);
811 dialect->addInterfaces<FuncToLLVMDialectInterface>();
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Find or create an external function declaration in the given module.