30 template <
typename Op,
typename... Args>
31 static Op getOrDefineGlobal(ModuleOp &moduleOp,
const Location loc,
35 if (!(ret = moduleOp.lookupSymbol<
Op>(name))) {
36 ConversionPatternRewriter::InsertionGuard guard(rewriter);
38 ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
47 LLVM::LLVMFunctionType type) {
48 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
49 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
52 std::pair<Value, Value> getRawPtrAndSize(
const Location loc,
57 rewriter.
create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
58 Value offset = rewriter.
create<LLVM::ExtractValueOp>(
61 rewriter.
create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
63 if (cast<LLVM::LLVMStructType>(memRef.
getType()).getBody().size() > 3) {
64 size = rewriter.
create<LLVM::ExtractValueOp>(loc, memRef,
68 size = rewriter.
create<arith::ConstantIntOp>(loc, 1, 32);
70 return {resPtr, size};
84 static std::unique_ptr<MPIImplTraits>
get(ModuleOp &moduleOp);
86 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
88 virtual ~MPIImplTraits() =
default;
90 ModuleOp &getModuleOp() {
return moduleOp; }
105 virtual intptr_t getStatusIgnore() = 0;
108 virtual void *getInPlace() = 0;
119 mpi::MPI_OpClassEnum opAttr) = 0;
126 class MPICHImplTraits :
public MPIImplTraits {
127 static constexpr
int MPI_FLOAT = 0x4c00040a;
128 static constexpr
int MPI_DOUBLE = 0x4c00080b;
129 static constexpr
int MPI_INT8_T = 0x4c000137;
130 static constexpr
int MPI_INT16_T = 0x4c000238;
131 static constexpr
int MPI_INT32_T = 0x4c000439;
132 static constexpr
int MPI_INT64_T = 0x4c00083a;
133 static constexpr
int MPI_UINT8_T = 0x4c00013b;
134 static constexpr
int MPI_UINT16_T = 0x4c00023c;
135 static constexpr
int MPI_UINT32_T = 0x4c00043d;
136 static constexpr
int MPI_UINT64_T = 0x4c00083e;
137 static constexpr
int MPI_MAX = 0x58000001;
138 static constexpr
int MPI_MIN = 0x58000002;
139 static constexpr
int MPI_SUM = 0x58000003;
140 static constexpr
int MPI_PROD = 0x58000004;
141 static constexpr
int MPI_LAND = 0x58000005;
142 static constexpr
int MPI_BAND = 0x58000006;
143 static constexpr
int MPI_LOR = 0x58000007;
144 static constexpr
int MPI_BOR = 0x58000008;
145 static constexpr
int MPI_LXOR = 0x58000009;
146 static constexpr
int MPI_BXOR = 0x5800000a;
147 static constexpr
int MPI_MINLOC = 0x5800000b;
148 static constexpr
int MPI_MAXLOC = 0x5800000c;
149 static constexpr
int MPI_REPLACE = 0x5800000d;
150 static constexpr
int MPI_NO_OP = 0x5800000e;
153 using MPIImplTraits::MPIImplTraits;
155 ~MPICHImplTraits()
override =
default;
159 static constexpr
int MPI_COMM_WORLD = 0x44000000;
165 Value comm)
override {
169 intptr_t getStatusIgnore()
override {
return 1; }
171 void *getInPlace()
override {
return reinterpret_cast<void *
>(-1); }
174 Type type)
override {
178 else if (type.
isF64())
183 mtype = MPI_UINT64_T;
187 mtype = MPI_UINT32_T;
191 mtype = MPI_UINT16_T;
197 assert(
false &&
"unsupported type");
198 return rewriter.
create<LLVM::ConstantOp>(loc, rewriter.
getI32Type(), mtype);
202 mpi::MPI_OpClassEnum opAttr)
override {
203 int32_t op = MPI_NO_OP;
205 case mpi::MPI_OpClassEnum::MPI_OP_NULL:
208 case mpi::MPI_OpClassEnum::MPI_MAX:
211 case mpi::MPI_OpClassEnum::MPI_MIN:
214 case mpi::MPI_OpClassEnum::MPI_SUM:
217 case mpi::MPI_OpClassEnum::MPI_PROD:
220 case mpi::MPI_OpClassEnum::MPI_LAND:
223 case mpi::MPI_OpClassEnum::MPI_BAND:
226 case mpi::MPI_OpClassEnum::MPI_LOR:
229 case mpi::MPI_OpClassEnum::MPI_BOR:
232 case mpi::MPI_OpClassEnum::MPI_LXOR:
235 case mpi::MPI_OpClassEnum::MPI_BXOR:
238 case mpi::MPI_OpClassEnum::MPI_MINLOC:
241 case mpi::MPI_OpClassEnum::MPI_MAXLOC:
244 case mpi::MPI_OpClassEnum::MPI_REPLACE:
255 class OMPIImplTraits :
public MPIImplTraits {
256 LLVM::GlobalOp getOrDefineExternalStruct(
const Location loc,
259 LLVM::LLVMStructType type) {
261 return getOrDefineGlobal<LLVM::GlobalOp>(
262 getModuleOp(), loc, rewriter, name, type,
false,
263 LLVM::Linkage::External, name,
268 using MPIImplTraits::MPIImplTraits;
270 ~OMPIImplTraits()
override =
default;
277 LLVM::LLVMStructType::getOpaque(
"ompi_communicator_t", context);
278 StringRef name =
"ompi_mpi_comm_world";
281 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
284 auto comm = rewriter.
create<LLVM::AddressOfOp>(
291 Value comm)
override {
292 return rewriter.
create<LLVM::IntToPtrOp>(
296 intptr_t getStatusIgnore()
override {
return 0; }
298 void *getInPlace()
override {
return reinterpret_cast<void *
>(1); }
301 Type type)
override {
304 mtype =
"ompi_mpi_float";
305 else if (type.
isF64())
306 mtype =
"ompi_mpi_double";
308 mtype =
"ompi_mpi_int64_t";
310 mtype =
"ompi_mpi_uint64_t";
312 mtype =
"ompi_mpi_int32_t";
314 mtype =
"ompi_mpi_uint32_t";
316 mtype =
"ompi_mpi_int16_t";
318 mtype =
"ompi_mpi_uint16_t";
320 mtype =
"ompi_mpi_int8_t";
322 mtype =
"ompi_mpi_uint8_t";
324 assert(
false &&
"unsupported type");
329 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_datatype_t", context);
331 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
333 return rewriter.
create<LLVM::AddressOfOp>(
339 mpi::MPI_OpClassEnum opAttr)
override {
342 case mpi::MPI_OpClassEnum::MPI_OP_NULL:
343 op =
"ompi_mpi_no_op";
345 case mpi::MPI_OpClassEnum::MPI_MAX:
348 case mpi::MPI_OpClassEnum::MPI_MIN:
351 case mpi::MPI_OpClassEnum::MPI_SUM:
354 case mpi::MPI_OpClassEnum::MPI_PROD:
355 op =
"ompi_mpi_prod";
357 case mpi::MPI_OpClassEnum::MPI_LAND:
358 op =
"ompi_mpi_land";
360 case mpi::MPI_OpClassEnum::MPI_BAND:
361 op =
"ompi_mpi_band";
363 case mpi::MPI_OpClassEnum::MPI_LOR:
366 case mpi::MPI_OpClassEnum::MPI_BOR:
369 case mpi::MPI_OpClassEnum::MPI_LXOR:
370 op =
"ompi_mpi_lxor";
372 case mpi::MPI_OpClassEnum::MPI_BXOR:
373 op =
"ompi_mpi_bxor";
375 case mpi::MPI_OpClassEnum::MPI_MINLOC:
376 op =
"ompi_mpi_minloc";
378 case mpi::MPI_OpClassEnum::MPI_MAXLOC:
379 op =
"ompi_mpi_maxloc";
381 case mpi::MPI_OpClassEnum::MPI_REPLACE:
382 op =
"ompi_mpi_replace";
388 LLVM::LLVMStructType::getOpaque(
"ompi_predefined_op_t", context);
390 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
392 return rewriter.
create<LLVM::AddressOfOp>(
399 auto attr =
dlti::query(*&moduleOp, {
"MPI:Implementation"},
true);
401 return std::make_unique<MPICHImplTraits>(moduleOp);
402 auto strAttr = dyn_cast<StringAttr>(attr.value());
403 if (strAttr && strAttr.getValue() ==
"OpenMPI")
404 return std::make_unique<OMPIImplTraits>(moduleOp);
405 if (!strAttr || strAttr.getValue() !=
"MPICH")
406 moduleOp.emitWarning() <<
"Unknown \"MPI:Implementation\" value in DLTI ("
407 << strAttr.getValue() <<
"), defaulting to MPICH";
408 return std::make_unique<MPICHImplTraits>(moduleOp);
419 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
427 auto nullPtrOp = rewriter.
create<LLVM::ZeroOp>(loc, ptrType);
428 Value llvmnull = nullPtrOp.getRes();
431 auto moduleOp = op->getParentOfType<ModuleOp>();
437 LLVM::LLVMFuncOp initDecl =
456 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
462 auto moduleOp = op->getParentOfType<ModuleOp>();
468 moduleOp, loc, rewriter,
"MPI_Finalize", initFuncType);
485 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
488 auto moduleOp = op->getParentOfType<ModuleOp>();
491 rewriter.
replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
505 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
508 auto moduleOp = op->getParentOfType<ModuleOp>();
515 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
516 auto one = rewriter.
create<LLVM::ConstantOp>(loc, i32, 1);
518 rewriter.
create<LLVM::AllocaOp>(loc, ptrType, comm.
getType(), one);
525 "MPI_Comm_split", funcType);
527 auto callOp = rewriter.
create<LLVM::CallOp>(
529 ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
533 Value res = rewriter.
create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
540 replacements.push_back(callOp.getResult());
543 replacements.push_back(res);
558 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
569 auto moduleOp = op->getParentOfType<ModuleOp>();
573 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
580 moduleOp, loc, rewriter,
"MPI_Comm_rank", rankFuncType);
583 auto one = rewriter.
create<LLVM::ConstantOp>(loc, i32, 1);
584 auto rankptr = rewriter.
create<LLVM::AllocaOp>(loc, ptrType, i32, one);
585 auto callOp = rewriter.
create<LLVM::CallOp>(
586 loc, initDecl,
ValueRange{comm, rankptr.getRes()});
590 rewriter.
create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
596 replacements.push_back(callOp.getResult());
599 replacements.push_back(loadedRank.getRes());
614 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
620 Type elemType = op.getRef().getType().getElementType();
626 auto moduleOp = op->getParentOfType<ModuleOp>();
629 auto [dataPtr, size] =
630 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
632 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
633 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
638 i32, {ptrType, i32, dataType.
getType(), i32, i32, comm.
getType()});
640 LLVM::LLVMFuncOp funcDecl =
644 auto funcCall = rewriter.
create<LLVM::CallOp>(
646 ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
649 rewriter.
replaceOp(op, funcCall.getResult());
665 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
672 Type elemType = op.getRef().getType().getElementType();
678 auto moduleOp = op->getParentOfType<ModuleOp>();
681 auto [dataPtr, size] =
682 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
684 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
685 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
686 Value statusIgnore = rewriter.
create<LLVM::ConstantOp>(
687 loc, i64, mpiTraits->getStatusIgnore());
689 rewriter.
create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
695 i32, comm.
getType(), ptrType});
697 LLVM::LLVMFuncOp funcDecl =
701 auto funcCall = rewriter.
create<LLVM::CallOp>(
703 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
704 adaptor.getTag(), comm, statusIgnore});
706 rewriter.
replaceOp(op, funcCall.getResult());
722 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
728 Type elemType = op.getSendbuf().getType().getElementType();
732 auto moduleOp = op->getParentOfType<ModuleOp>();
734 auto [sendPtr, sendSize] =
735 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
736 auto [recvPtr, recvSize] =
737 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
740 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
741 sendPtr = rewriter.
create<LLVM::ConstantOp>(
742 loc, i64,
reinterpret_cast<int64_t
>(mpiTraits->getInPlace()));
743 sendPtr = rewriter.
create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
746 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
747 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
748 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
756 LLVM::LLVMFuncOp funcDecl =
760 auto funcCall = rewriter.
create<LLVM::CallOp>(
762 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
765 rewriter.
replaceOp(op, funcCall.getResult());
782 void populateConvertToLLVMConversionPatterns(
802 patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
803 FinalizeOpLowering, InitOpLowering, SendOpLowering,
804 RecvOpLowering, AllReduceOpLowering>(converter);
809 dialect->addInterfaces<FuncToLLVMDialectInterface>();
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Find or create an external function declaration in the given module.