28 template <
typename Op,
typename... Args>
29 static Op getOrDefineGlobal(ModuleOp &moduleOp,
const Location loc,
33 if (!(ret = moduleOp.lookupSymbol<
Op>(name))) {
34 ConversionPatternRewriter::InsertionGuard guard(rewriter);
36 ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
45 LLVM::LLVMFunctionType type) {
46 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
47 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
61 static std::unique_ptr<MPIImplTraits>
get(ModuleOp &moduleOp);
63 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
65 virtual ~MPIImplTraits() =
default;
67 ModuleOp &getModuleOp() {
return moduleOp; }
74 virtual intptr_t getStatusIgnore() = 0;
86 class MPICHImplTraits :
public MPIImplTraits {
87 static constexpr
int MPI_FLOAT = 0x4c00040a;
88 static constexpr
int MPI_DOUBLE = 0x4c00080b;
89 static constexpr
int MPI_INT8_T = 0x4c000137;
90 static constexpr
int MPI_INT16_T = 0x4c000238;
91 static constexpr
int MPI_INT32_T = 0x4c000439;
92 static constexpr
int MPI_INT64_T = 0x4c00083a;
93 static constexpr
int MPI_UINT8_T = 0x4c00013b;
94 static constexpr
int MPI_UINT16_T = 0x4c00023c;
95 static constexpr
int MPI_UINT32_T = 0x4c00043d;
96 static constexpr
int MPI_UINT64_T = 0x4c00083e;
99 using MPIImplTraits::MPIImplTraits;
101 ~MPICHImplTraits()
override =
default;
105 static constexpr
int MPI_COMM_WORLD = 0x44000000;
110 intptr_t getStatusIgnore()
override {
return 1; }
113 Type type)
override {
117 else if (type.
isF64())
122 mtype = MPI_UINT64_T;
126 mtype = MPI_UINT32_T;
130 mtype = MPI_UINT16_T;
136 assert(
false &&
"unsupported type");
137 return rewriter.
create<LLVM::ConstantOp>(loc, rewriter.
getI32Type(), mtype);
144 class OMPIImplTraits :
public MPIImplTraits {
145 LLVM::GlobalOp getOrDefineExternalStruct(
const Location loc,
148 LLVM::LLVMStructType type) {
150 return getOrDefineGlobal<LLVM::GlobalOp>(
151 getModuleOp(), loc, rewriter, name, type,
false,
152 LLVM::Linkage::External, name,
157 using MPIImplTraits::MPIImplTraits;
159 ~OMPIImplTraits()
override =
default;
166 LLVM::LLVMStructType::getOpaque(
"ompi_communicator_t", context);
167 StringRef name =
"ompi_mpi_comm_world";
170 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
173 return rewriter.
create<LLVM::AddressOfOp>(
178 intptr_t getStatusIgnore()
override {
return 0; }
181 Type type)
override {
184 mtype =
"ompi_mpi_float";
185 else if (type.
isF64())
186 mtype =
"ompi_mpi_double";
188 mtype =
"ompi_mpi_int64_t";
190 mtype =
"ompi_mpi_uint64_t";
192 mtype =
"ompi_mpi_int32_t";
194 mtype =
"ompi_mpi_uint32_t";
196 mtype =
"ompi_mpi_int16_t";
198 mtype =
"ompi_mpi_uint16_t";
200 mtype =
"ompi_mpi_int8_t";
202 mtype =
"ompi_mpi_uint8_t";
204 assert(
false &&
"unsupported type");
209 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_datatype_t", context);
211 getOrDefineExternalStruct(loc, rewriter, mtype, commStructT);
213 return rewriter.
create<LLVM::AddressOfOp>(
220 auto attr =
dlti::query(*&moduleOp, {
"MPI:Implementation"},
true);
222 return std::make_unique<MPICHImplTraits>(moduleOp);
223 auto strAttr = dyn_cast<StringAttr>(attr.value());
224 if (strAttr && strAttr.getValue() ==
"OpenMPI")
225 return std::make_unique<OMPIImplTraits>(moduleOp);
226 if (!strAttr || strAttr.getValue() !=
"MPICH")
227 moduleOp.emitWarning() <<
"Unknown \"MPI:Implementation\" value in DLTI ("
228 << strAttr.getValue() <<
"), defaulting to MPICH";
229 return std::make_unique<MPICHImplTraits>(moduleOp);
240 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
248 auto nullPtrOp = rewriter.
create<LLVM::ZeroOp>(loc, ptrType);
249 Value llvmnull = nullPtrOp.getRes();
252 auto moduleOp = op->getParentOfType<ModuleOp>();
258 LLVM::LLVMFuncOp initDecl =
277 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
283 auto moduleOp = op->getParentOfType<ModuleOp>();
289 moduleOp, loc, rewriter,
"MPI_Finalize", initFuncType);
306 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
317 auto moduleOp = op->getParentOfType<ModuleOp>();
321 Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
328 moduleOp, loc, rewriter,
"MPI_Comm_rank", rankFuncType);
331 auto one = rewriter.
create<LLVM::ConstantOp>(loc, i32, 1);
332 auto rankptr = rewriter.
create<LLVM::AllocaOp>(loc, ptrType, i32, one);
333 auto callOp = rewriter.
create<LLVM::CallOp>(
334 loc, initDecl,
ValueRange{commWorld, rankptr.getRes()});
338 rewriter.
create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
344 replacements.push_back(callOp.getResult());
347 replacements.push_back(loadedRank.getRes());
362 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
369 Value memRef = adaptor.getRef();
370 Type elemType = op.getRef().getType().getElementType();
376 auto moduleOp = op->getParentOfType<ModuleOp>();
380 rewriter.
create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
381 Value offset = rewriter.
create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
383 rewriter.
create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
384 Value size = rewriter.
create<LLVM::ExtractValueOp>(loc, memRef,
386 size = rewriter.
create<LLVM::TruncOp>(loc, i32, size);
388 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
389 Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
394 i32, {ptrType, i32, dataType.
getType(), i32, i32, commWorld.
getType()});
396 LLVM::LLVMFuncOp funcDecl =
400 auto funcCall = rewriter.
create<LLVM::CallOp>(
402 ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
405 rewriter.
replaceOp(op, funcCall.getResult());
421 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
428 Value memRef = adaptor.getRef();
429 Type elemType = op.getRef().getType().getElementType();
435 auto moduleOp = op->getParentOfType<ModuleOp>();
439 rewriter.
create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
440 Value offset = rewriter.
create<LLVM::ExtractValueOp>(loc, i64, memRef, 2);
442 rewriter.
create<LLVM::GEPOp>(loc, ptrType, elemType, dataPtr, offset);
443 Value size = rewriter.
create<LLVM::ExtractValueOp>(loc, memRef,
445 size = rewriter.
create<LLVM::TruncOp>(loc, i32, size);
447 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
448 Value commWorld = mpiTraits->getCommWorld(loc, rewriter);
449 Value statusIgnore = rewriter.
create<LLVM::ConstantOp>(
450 loc, i64, mpiTraits->getStatusIgnore());
452 rewriter.
create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
458 i32, commWorld.
getType(), ptrType});
460 LLVM::LLVMFuncOp funcDecl =
464 auto funcCall = rewriter.
create<LLVM::CallOp>(
466 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
467 adaptor.getTag(), commWorld, statusIgnore});
469 rewriter.
replaceOp(op, funcCall.getResult());
486 void populateConvertToLLVMConversionPatterns(
500 patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering,
501 SendOpLowering, RecvOpLowering>(converter);
506 dialect->addInterfaces<FuncToLLVMDialectInterface>();
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry ®istry)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Find or create an external function declaration in the given module.