30 template <
typename Op,
typename... Args>
31 static Op getOrDefineGlobal(ModuleOp &moduleOp,
const Location loc,
35 if (!(ret = moduleOp.lookupSymbol<
Op>(name))) {
36 ConversionPatternRewriter::InsertionGuard guard(rewriter);
38 ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
47 LLVM::LLVMFunctionType type) {
48 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
49 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
52 std::pair<Value, Value> getRawPtrAndSize(
const Location loc,
57 LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
58 Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
61 LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
63 if (cast<LLVM::LLVMStructType>(memRef.
getType()).getBody().size() > 3) {
64 size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
66 size = LLVM::TruncOp::create(rewriter, loc, rewriter.
getI32Type(), size);
70 return {resPtr, size};
84 static std::unique_ptr<MPIImplTraits>
get(ModuleOp &moduleOp);
86 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
88 virtual ~MPIImplTraits() =
default;
90 ModuleOp &getModuleOp() {
return moduleOp; }
105 virtual intptr_t getStatusIgnore() = 0;
108 virtual void *getInPlace() = 0;
119 mpi::MPI_ReductionOpEnum opAttr) = 0;
126 class MPICHImplTraits :
public MPIImplTraits {
127 static constexpr
int MPI_FLOAT = 0x4c00040a;
128 static constexpr
int MPI_DOUBLE = 0x4c00080b;
129 static constexpr
int MPI_INT8_T = 0x4c000137;
130 static constexpr
int MPI_INT16_T = 0x4c000238;
131 static constexpr
int MPI_INT32_T = 0x4c000439;
132 static constexpr
int MPI_INT64_T = 0x4c00083a;
133 static constexpr
int MPI_UINT8_T = 0x4c00013b;
134 static constexpr
int MPI_UINT16_T = 0x4c00023c;
135 static constexpr
int MPI_UINT32_T = 0x4c00043d;
136 static constexpr
int MPI_UINT64_T = 0x4c00083e;
137 static constexpr
int MPI_MAX = 0x58000001;
138 static constexpr
int MPI_MIN = 0x58000002;
139 static constexpr
int MPI_SUM = 0x58000003;
140 static constexpr
int MPI_PROD = 0x58000004;
141 static constexpr
int MPI_LAND = 0x58000005;
142 static constexpr
int MPI_BAND = 0x58000006;
143 static constexpr
int MPI_LOR = 0x58000007;
144 static constexpr
int MPI_BOR = 0x58000008;
145 static constexpr
int MPI_LXOR = 0x58000009;
146 static constexpr
int MPI_BXOR = 0x5800000a;
147 static constexpr
int MPI_MINLOC = 0x5800000b;
148 static constexpr
int MPI_MAXLOC = 0x5800000c;
149 static constexpr
int MPI_REPLACE = 0x5800000d;
150 static constexpr
int MPI_NO_OP = 0x5800000e;
153 using MPIImplTraits::MPIImplTraits;
155 ~MPICHImplTraits()
override =
default;
159 static constexpr
int MPI_COMM_WORLD = 0x44000000;
160 return LLVM::ConstantOp::create(rewriter, loc, rewriter.
getI64Type(),
165 Value comm)
override {
166 return LLVM::TruncOp::create(rewriter, loc, rewriter.
getI32Type(), comm);
169 intptr_t getStatusIgnore()
override {
return 1; }
171 void *getInPlace()
override {
return reinterpret_cast<void *
>(-1); }
174 Type type)
override {
178 else if (type.
isF64())
183 mtype = MPI_UINT64_T;
187 mtype = MPI_UINT32_T;
191 mtype = MPI_UINT16_T;
197 assert(
false &&
"unsupported type");
198 return LLVM::ConstantOp::create(rewriter, loc, rewriter.
getI32Type(),
203 mpi::MPI_ReductionOpEnum opAttr)
override {
204 int32_t op = MPI_NO_OP;
206 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
209 case mpi::MPI_ReductionOpEnum::MPI_MAX:
212 case mpi::MPI_ReductionOpEnum::MPI_MIN:
215 case mpi::MPI_ReductionOpEnum::MPI_SUM:
218 case mpi::MPI_ReductionOpEnum::MPI_PROD:
221 case mpi::MPI_ReductionOpEnum::MPI_LAND:
224 case mpi::MPI_ReductionOpEnum::MPI_BAND:
227 case mpi::MPI_ReductionOpEnum::MPI_LOR:
230 case mpi::MPI_ReductionOpEnum::MPI_BOR:
233 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
236 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
239 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
242 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
245 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
249 return LLVM::ConstantOp::create(rewriter, loc, rewriter.
getI32Type(), op);
256 class OMPIImplTraits :
public MPIImplTraits {
257 LLVM::GlobalOp getOrDefineExternalStruct(
const Location loc,
260 LLVM::LLVMStructType type) {
262 return getOrDefineGlobal<LLVM::GlobalOp>(
263 getModuleOp(), loc, rewriter, name, type,
false,
264 LLVM::Linkage::External, name,
269 using MPIImplTraits::MPIImplTraits;
271 ~OMPIImplTraits()
override =
default;
278 LLVM::LLVMStructType::getOpaque(
"ompi_communicator_t", context);
279 StringRef name =
"ompi_mpi_comm_world";
282 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
285 auto comm = LLVM::AddressOfOp::create(rewriter, loc,
288 return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.
getI64Type(), comm);
292 Value comm)
override {
293 return LLVM::IntToPtrOp::create(
297 intptr_t getStatusIgnore()
override {
return 0; }
299 void *getInPlace()
override {
return reinterpret_cast<void *
>(1); }
302 Type type)
override {
305 mtype =
"ompi_mpi_float";
306 else if (type.
isF64())
307 mtype =
"ompi_mpi_double";
309 mtype =
"ompi_mpi_int64_t";
311 mtype =
"ompi_mpi_uint64_t";
313 mtype =
"ompi_mpi_int32_t";
315 mtype =
"ompi_mpi_uint32_t";
317 mtype =
"ompi_mpi_int16_t";
319 mtype =
"ompi_mpi_uint16_t";
321 mtype =
"ompi_mpi_int8_t";
323 mtype =
"ompi_mpi_uint8_t";
325 assert(
false &&
"unsupported type");
330 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_datatype_t", context);
332 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
334 return LLVM::AddressOfOp::create(rewriter, loc,
340 mpi::MPI_ReductionOpEnum opAttr)
override {
343 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
344 op =
"ompi_mpi_no_op";
346 case mpi::MPI_ReductionOpEnum::MPI_MAX:
349 case mpi::MPI_ReductionOpEnum::MPI_MIN:
352 case mpi::MPI_ReductionOpEnum::MPI_SUM:
355 case mpi::MPI_ReductionOpEnum::MPI_PROD:
356 op =
"ompi_mpi_prod";
358 case mpi::MPI_ReductionOpEnum::MPI_LAND:
359 op =
"ompi_mpi_land";
361 case mpi::MPI_ReductionOpEnum::MPI_BAND:
362 op =
"ompi_mpi_band";
364 case mpi::MPI_ReductionOpEnum::MPI_LOR:
367 case mpi::MPI_ReductionOpEnum::MPI_BOR:
370 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
371 op =
"ompi_mpi_lxor";
373 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
374 op =
"ompi_mpi_bxor";
376 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
377 op =
"ompi_mpi_minloc";
379 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
380 op =
"ompi_mpi_maxloc";
382 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
383 op =
"ompi_mpi_replace";
389 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_op_t", context);
391 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
393 return LLVM::AddressOfOp::create(rewriter, loc,
400 auto attr =
dlti::query(*&moduleOp, {
"MPI:Implementation"},
true);
402 return std::make_unique<MPICHImplTraits>(moduleOp);
403 auto strAttr = dyn_cast<StringAttr>(attr.value());
404 if (strAttr && strAttr.getValue() ==
"OpenMPI")
405 return std::make_unique<OMPIImplTraits>(moduleOp);
406 if (!strAttr || strAttr.getValue() !=
"MPICH")
407 moduleOp.emitWarning() <<
"Unknown \"MPI:Implementation\" value in DLTI ("
408 << (strAttr ? strAttr.getValue() :
"<NULL>")
409 <<
"), defaulting to MPICH";
410 return std::make_unique<MPICHImplTraits>(moduleOp);
421 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
429 auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
430 Value llvmnull = nullPtrOp.getRes();
433 auto moduleOp = op->getParentOfType<ModuleOp>();
439 LLVM::LLVMFuncOp initDecl =
458 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
464 auto moduleOp = op->getParentOfType<ModuleOp>();
470 moduleOp, loc, rewriter,
"MPI_Finalize", initFuncType);
487 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
490 auto moduleOp = op->getParentOfType<ModuleOp>();
493 rewriter.
replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
507 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
510 auto moduleOp = op->getParentOfType<ModuleOp>();
517 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
518 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
520 LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.
getType(), one);
527 "MPI_Comm_split", funcType);
530 LLVM::CallOp::create(rewriter, loc, funcDecl,
532 adaptor.getKey(), outPtr.getRes()});
535 Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
536 res = LLVM::SExtOp::create(rewriter, loc, rewriter.
getI64Type(), res);
542 replacements.push_back(callOp.getResult());
545 replacements.push_back(res);
560 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
571 auto moduleOp = op->getParentOfType<ModuleOp>();
575 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
582 moduleOp, loc, rewriter,
"MPI_Comm_rank", rankFuncType);
585 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
586 auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
587 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
592 LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
598 replacements.push_back(callOp.getResult());
601 replacements.push_back(loadedRank.getRes());
616 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
622 Type elemType = op.getRef().getType().getElementType();
628 auto moduleOp = op->getParentOfType<ModuleOp>();
631 auto [dataPtr, size] =
632 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
634 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
635 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
640 i32, {ptrType, i32, dataType.
getType(), i32, i32, comm.
getType()});
642 LLVM::LLVMFuncOp funcDecl =
646 auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
649 adaptor.getTag(), comm});
651 rewriter.
replaceOp(op, funcCall.getResult());
667 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
674 Type elemType = op.getRef().getType().getElementType();
680 auto moduleOp = op->getParentOfType<ModuleOp>();
683 auto [dataPtr, size] =
684 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
686 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
687 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
688 Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
689 mpiTraits->getStatusIgnore());
691 LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
697 i32, comm.
getType(), ptrType});
699 LLVM::LLVMFuncOp funcDecl =
703 auto funcCall = LLVM::CallOp::create(
704 rewriter, loc, funcDecl,
705 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
706 adaptor.getTag(), comm, statusIgnore});
708 rewriter.
replaceOp(op, funcCall.getResult());
724 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
730 Type elemType = op.getSendbuf().getType().getElementType();
734 auto moduleOp = op->getParentOfType<ModuleOp>();
736 auto [sendPtr, sendSize] =
737 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
738 auto [recvPtr, recvSize] =
739 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
742 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
743 sendPtr = LLVM::ConstantOp::create(
745 reinterpret_cast<int64_t
>(mpiTraits->getInPlace()));
746 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
749 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
750 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
751 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
759 LLVM::LLVMFuncOp funcDecl =
763 auto funcCall = LLVM::CallOp::create(
764 rewriter, loc, funcDecl,
765 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
768 rewriter.
replaceOp(op, funcCall.getResult());
785 void populateConvertToLLVMConversionPatterns(
805 patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
806 FinalizeOpLowering, InitOpLowering, SendOpLowering,
807 RecvOpLowering, AllReduceOpLowering>(converter);
812 dialect->addInterfaces<FuncToLLVMDialectInterface>();
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Find or create an external function declaration in the given module.