MLIR 22.0.0git
MPIToLLVM.cpp
Go to the documentation of this file.
1//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9//
10// Copyright (C) by Argonne National Laboratory
11// See COPYRIGHT in top-level directory
12// of MPICH source repository.
13//
14
24#include <memory>
25
26using namespace mlir;
27
28namespace {
29
30template <typename Op, typename... Args>
31static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
32 ConversionPatternRewriter &rewriter, StringRef name,
33 Args &&...args) {
34 Op ret;
35 if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
36 ConversionPatternRewriter::InsertionGuard guard(rewriter);
37 rewriter.setInsertionPointToStart(moduleOp.getBody());
38 ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
39 }
40 return ret;
41}
42
43static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
44 const Location loc,
45 ConversionPatternRewriter &rewriter,
46 StringRef name,
47 LLVM::LLVMFunctionType type) {
48 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
49 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
50}
51
52std::pair<Value, Value> getRawPtrAndSize(const Location loc,
53 ConversionPatternRewriter &rewriter,
54 Value memRef, Type elType) {
55 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
56 Value dataPtr =
57 LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
58 Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
59 rewriter.getI64Type(), memRef, 2);
60 Value resPtr =
61 LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
62 Value size;
63 if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
64 size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
65 ArrayRef<int64_t>{3, 0});
66 size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
67 } else {
68 size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
69 }
70 return {resPtr, size};
71}
72
73/// When lowering the mpi dialect to functions calls certain details
74/// differ between various MPI implementations. This class will provide
75/// these in a generic way, depending on the MPI implementation that got
76/// selected by the DLTI attribute on the module.
77class MPIImplTraits {
78 ModuleOp &moduleOp;
79
80public:
81 /// Instantiate a new MPIImplTraits object according to the DLTI attribute
82 /// on the given module. Default to MPICH if no attribute is present or
83 /// the value is unknown.
84 static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
85
86 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
87
88 virtual ~MPIImplTraits() = default;
89
90 ModuleOp &getModuleOp() { return moduleOp; }
91
92 /// Gets or creates MPI_COMM_WORLD as a Value.
93 /// Different MPI implementations have different communicator types.
94 /// Using i64 as a portable, intermediate type.
95 /// Appropriate cast needs to take place before calling MPI functions.
96 virtual Value getCommWorld(const Location loc,
97 ConversionPatternRewriter &rewriter) = 0;
98
99 /// Type converter provides i64 type for communicator type.
100 /// Converts to native type, which might be ptr or int or whatever.
101 virtual Value castComm(const Location loc,
102 ConversionPatternRewriter &rewriter, Value comm) = 0;
103
104 /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
105 virtual intptr_t getStatusIgnore() = 0;
106
107 /// Get the MPI_IN_PLACE value (void *).
108 virtual void *getInPlace() = 0;
109
110 /// Gets or creates an MPI datatype as a value which corresponds to the given
111 /// type.
112 virtual Value getDataType(const Location loc,
113 ConversionPatternRewriter &rewriter, Type type) = 0;
114
115 /// Gets or creates an MPI_Op value which corresponds to the given
116 /// enum value.
117 virtual Value getMPIOp(const Location loc,
118 ConversionPatternRewriter &rewriter,
119 mpi::MPI_ReductionOpEnum opAttr) = 0;
120};
121
122//===----------------------------------------------------------------------===//
123// Implementation details for MPICH ABI compatible MPI implementations
124//===----------------------------------------------------------------------===//
125
126class MPICHImplTraits : public MPIImplTraits {
127 static constexpr int MPI_FLOAT = 0x4c00040a;
128 static constexpr int MPI_DOUBLE = 0x4c00080b;
129 static constexpr int MPI_INT8_T = 0x4c000137;
130 static constexpr int MPI_INT16_T = 0x4c000238;
131 static constexpr int MPI_INT32_T = 0x4c000439;
132 static constexpr int MPI_INT64_T = 0x4c00083a;
133 static constexpr int MPI_UINT8_T = 0x4c00013b;
134 static constexpr int MPI_UINT16_T = 0x4c00023c;
135 static constexpr int MPI_UINT32_T = 0x4c00043d;
136 static constexpr int MPI_UINT64_T = 0x4c00083e;
137 static constexpr int MPI_MAX = 0x58000001;
138 static constexpr int MPI_MIN = 0x58000002;
139 static constexpr int MPI_SUM = 0x58000003;
140 static constexpr int MPI_PROD = 0x58000004;
141 static constexpr int MPI_LAND = 0x58000005;
142 static constexpr int MPI_BAND = 0x58000006;
143 static constexpr int MPI_LOR = 0x58000007;
144 static constexpr int MPI_BOR = 0x58000008;
145 static constexpr int MPI_LXOR = 0x58000009;
146 static constexpr int MPI_BXOR = 0x5800000a;
147 static constexpr int MPI_MINLOC = 0x5800000b;
148 static constexpr int MPI_MAXLOC = 0x5800000c;
149 static constexpr int MPI_REPLACE = 0x5800000d;
150 static constexpr int MPI_NO_OP = 0x5800000e;
151
152public:
153 using MPIImplTraits::MPIImplTraits;
154
155 ~MPICHImplTraits() override = default;
156
157 Value getCommWorld(const Location loc,
158 ConversionPatternRewriter &rewriter) override {
159 static constexpr int MPI_COMM_WORLD = 0x44000000;
160 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
161 MPI_COMM_WORLD);
162 }
163
164 Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
165 Value comm) override {
166 return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
167 }
168
169 intptr_t getStatusIgnore() override { return 1; }
170
171 void *getInPlace() override { return reinterpret_cast<void *>(-1); }
172
173 Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
174 Type type) override {
175 int32_t mtype = 0;
176 if (type.isF32())
177 mtype = MPI_FLOAT;
178 else if (type.isF64())
179 mtype = MPI_DOUBLE;
180 else if (type.isInteger(64) && !type.isUnsignedInteger())
181 mtype = MPI_INT64_T;
182 else if (type.isInteger(64))
183 mtype = MPI_UINT64_T;
184 else if (type.isInteger(32) && !type.isUnsignedInteger())
185 mtype = MPI_INT32_T;
186 else if (type.isInteger(32))
187 mtype = MPI_UINT32_T;
188 else if (type.isInteger(16) && !type.isUnsignedInteger())
189 mtype = MPI_INT16_T;
190 else if (type.isInteger(16))
191 mtype = MPI_UINT16_T;
192 else if (type.isInteger(8) && !type.isUnsignedInteger())
193 mtype = MPI_INT8_T;
194 else if (type.isInteger(8))
195 mtype = MPI_UINT8_T;
196 else
197 assert(false && "unsupported type");
198 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
199 mtype);
200 }
201
202 Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
203 mpi::MPI_ReductionOpEnum opAttr) override {
204 int32_t op = MPI_NO_OP;
205 switch (opAttr) {
206 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
207 op = MPI_NO_OP;
208 break;
209 case mpi::MPI_ReductionOpEnum::MPI_MAX:
210 op = MPI_MAX;
211 break;
212 case mpi::MPI_ReductionOpEnum::MPI_MIN:
213 op = MPI_MIN;
214 break;
215 case mpi::MPI_ReductionOpEnum::MPI_SUM:
216 op = MPI_SUM;
217 break;
218 case mpi::MPI_ReductionOpEnum::MPI_PROD:
219 op = MPI_PROD;
220 break;
221 case mpi::MPI_ReductionOpEnum::MPI_LAND:
222 op = MPI_LAND;
223 break;
224 case mpi::MPI_ReductionOpEnum::MPI_BAND:
225 op = MPI_BAND;
226 break;
227 case mpi::MPI_ReductionOpEnum::MPI_LOR:
228 op = MPI_LOR;
229 break;
230 case mpi::MPI_ReductionOpEnum::MPI_BOR:
231 op = MPI_BOR;
232 break;
233 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
234 op = MPI_LXOR;
235 break;
236 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
237 op = MPI_BXOR;
238 break;
239 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
240 op = MPI_MINLOC;
241 break;
242 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
243 op = MPI_MAXLOC;
244 break;
245 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
246 op = MPI_REPLACE;
247 break;
248 }
249 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
250 }
251};
252
253//===----------------------------------------------------------------------===//
254// Implementation details for OpenMPI
255//===----------------------------------------------------------------------===//
256class OMPIImplTraits : public MPIImplTraits {
257 LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
258 ConversionPatternRewriter &rewriter,
259 StringRef name,
260 LLVM::LLVMStructType type) {
261
262 return getOrDefineGlobal<LLVM::GlobalOp>(
263 getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
264 LLVM::Linkage::External, name,
265 /*value=*/Attribute(), /*alignment=*/0, 0);
266 }
267
268public:
269 using MPIImplTraits::MPIImplTraits;
270
271 ~OMPIImplTraits() override = default;
272
273 Value getCommWorld(const Location loc,
274 ConversionPatternRewriter &rewriter) override {
275 auto *context = rewriter.getContext();
276 // get external opaque struct pointer type
277 auto commStructT =
278 LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
279 StringRef name = "ompi_mpi_comm_world";
280
281 // make sure global op definition exists
282 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
283
284 // get address of symbol
285 auto comm = LLVM::AddressOfOp::create(rewriter, loc,
286 LLVM::LLVMPointerType::get(context),
287 SymbolRefAttr::get(context, name));
288 return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
289 }
290
291 Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
292 Value comm) override {
293 return LLVM::IntToPtrOp::create(
294 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
295 }
296
297 intptr_t getStatusIgnore() override { return 0; }
298
299 void *getInPlace() override { return reinterpret_cast<void *>(1); }
300
301 Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
302 Type type) override {
303 StringRef mtype;
304 if (type.isF32())
305 mtype = "ompi_mpi_float";
306 else if (type.isF64())
307 mtype = "ompi_mpi_double";
308 else if (type.isInteger(64) && !type.isUnsignedInteger())
309 mtype = "ompi_mpi_int64_t";
310 else if (type.isInteger(64))
311 mtype = "ompi_mpi_uint64_t";
312 else if (type.isInteger(32) && !type.isUnsignedInteger())
313 mtype = "ompi_mpi_int32_t";
314 else if (type.isInteger(32))
315 mtype = "ompi_mpi_uint32_t";
316 else if (type.isInteger(16) && !type.isUnsignedInteger())
317 mtype = "ompi_mpi_int16_t";
318 else if (type.isInteger(16))
319 mtype = "ompi_mpi_uint16_t";
320 else if (type.isInteger(8) && !type.isUnsignedInteger())
321 mtype = "ompi_mpi_int8_t";
322 else if (type.isInteger(8))
323 mtype = "ompi_mpi_uint8_t";
324 else
325 assert(false && "unsupported type");
326
327 auto *context = rewriter.getContext();
328 // get external opaque struct pointer type
329 auto typeStructT =
330 LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
331 // make sure global op definition exists
332 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
333 // get address of symbol
334 return LLVM::AddressOfOp::create(rewriter, loc,
335 LLVM::LLVMPointerType::get(context),
336 SymbolRefAttr::get(context, mtype));
337 }
338
339 Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
340 mpi::MPI_ReductionOpEnum opAttr) override {
341 StringRef op;
342 switch (opAttr) {
343 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
344 op = "ompi_mpi_no_op";
345 break;
346 case mpi::MPI_ReductionOpEnum::MPI_MAX:
347 op = "ompi_mpi_max";
348 break;
349 case mpi::MPI_ReductionOpEnum::MPI_MIN:
350 op = "ompi_mpi_min";
351 break;
352 case mpi::MPI_ReductionOpEnum::MPI_SUM:
353 op = "ompi_mpi_sum";
354 break;
355 case mpi::MPI_ReductionOpEnum::MPI_PROD:
356 op = "ompi_mpi_prod";
357 break;
358 case mpi::MPI_ReductionOpEnum::MPI_LAND:
359 op = "ompi_mpi_land";
360 break;
361 case mpi::MPI_ReductionOpEnum::MPI_BAND:
362 op = "ompi_mpi_band";
363 break;
364 case mpi::MPI_ReductionOpEnum::MPI_LOR:
365 op = "ompi_mpi_lor";
366 break;
367 case mpi::MPI_ReductionOpEnum::MPI_BOR:
368 op = "ompi_mpi_bor";
369 break;
370 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
371 op = "ompi_mpi_lxor";
372 break;
373 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
374 op = "ompi_mpi_bxor";
375 break;
376 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
377 op = "ompi_mpi_minloc";
378 break;
379 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
380 op = "ompi_mpi_maxloc";
381 break;
382 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
383 op = "ompi_mpi_replace";
384 break;
385 }
386 auto *context = rewriter.getContext();
387 // get external opaque struct pointer type
388 auto opStructT =
389 LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
390 // make sure global op definition exists
391 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
392 // get address of symbol
393 return LLVM::AddressOfOp::create(rewriter, loc,
394 LLVM::LLVMPointerType::get(context),
395 SymbolRefAttr::get(context, op));
396 }
397};
398
399std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
400 auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
401 if (failed(attr))
402 return std::make_unique<MPICHImplTraits>(moduleOp);
403 auto strAttr = dyn_cast<StringAttr>(attr.value());
404 if (strAttr && strAttr.getValue() == "OpenMPI")
405 return std::make_unique<OMPIImplTraits>(moduleOp);
406 if (!strAttr || strAttr.getValue() != "MPICH")
407 moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
408 << (strAttr ? strAttr.getValue() : "<NULL>")
409 << "), defaulting to MPICH";
410 return std::make_unique<MPICHImplTraits>(moduleOp);
411}
412
413//===----------------------------------------------------------------------===//
414// InitOpLowering
415//===----------------------------------------------------------------------===//
416
417struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
419
420 LogicalResult
421 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
422 ConversionPatternRewriter &rewriter) const override {
423 Location loc = op.getLoc();
424
425 // ptrType `!llvm.ptr`
426 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
427
428 // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
429 auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
430 Value llvmnull = nullPtrOp.getRes();
431
432 // grab a reference to the global module op:
433 auto moduleOp = op->getParentOfType<ModuleOp>();
434
435 // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
436 auto initFuncType =
437 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
438 // get or create function declaration:
439 LLVM::LLVMFuncOp initDecl =
440 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
441
442 // replace init with function call
443 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
444 ValueRange{llvmnull, llvmnull});
445
446 return success();
447 }
448};
449
450//===----------------------------------------------------------------------===//
451// FinalizeOpLowering
452//===----------------------------------------------------------------------===//
453
454struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
456
457 LogicalResult
458 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
459 ConversionPatternRewriter &rewriter) const override {
460 // get loc
461 Location loc = op.getLoc();
462
463 // grab a reference to the global module op:
464 auto moduleOp = op->getParentOfType<ModuleOp>();
465
466 // LLVM Function type representing `i32 MPI_Finalize()`
467 auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
468 // get or create function declaration:
469 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
470 moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
471
472 // replace init with function call
473 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
474
475 return success();
476 }
477};
478
479//===----------------------------------------------------------------------===//
480// CommWorldOpLowering
481//===----------------------------------------------------------------------===//
482
483struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
485
486 LogicalResult
487 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
488 ConversionPatternRewriter &rewriter) const override {
489 // grab a reference to the global module op:
490 auto moduleOp = op->getParentOfType<ModuleOp>();
491 auto mpiTraits = MPIImplTraits::get(moduleOp);
492 // get MPI_COMM_WORLD
493 rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
494
495 return success();
496 }
497};
498
499//===----------------------------------------------------------------------===//
500// CommSplitOpLowering
501//===----------------------------------------------------------------------===//
502
503struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
505
506 LogicalResult
507 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
508 ConversionPatternRewriter &rewriter) const override {
509 // grab a reference to the global module op:
510 auto moduleOp = op->getParentOfType<ModuleOp>();
511 auto mpiTraits = MPIImplTraits::get(moduleOp);
512 Type i32 = rewriter.getI32Type();
513 Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
514 Location loc = op.getLoc();
515
516 // get communicator
517 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
518 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
519 auto outPtr =
520 LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one);
521
522 // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
523 auto funcType =
524 LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
525 // get or create function declaration:
526 LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
527 "MPI_Comm_split", funcType);
528
529 auto callOp =
530 LLVM::CallOp::create(rewriter, loc, funcDecl,
531 ValueRange{comm, adaptor.getColor(),
532 adaptor.getKey(), outPtr.getRes()});
533
534 // load the communicator into a register
535 Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
536 res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
537
538 // if retval is checked, replace uses of retval with the results from the
539 // call op
540 SmallVector<Value> replacements;
541 if (op.getRetval())
542 replacements.push_back(callOp.getResult());
543
544 // replace op
545 replacements.push_back(res);
546 rewriter.replaceOp(op, replacements);
547
548 return success();
549 }
550};
551
552//===----------------------------------------------------------------------===//
553// CommRankOpLowering
554//===----------------------------------------------------------------------===//
555
556struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
558
559 LogicalResult
560 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
561 ConversionPatternRewriter &rewriter) const override {
562 // get some helper vars
563 Location loc = op.getLoc();
564 MLIRContext *context = rewriter.getContext();
565 Type i32 = rewriter.getI32Type();
566
567 // ptrType `!llvm.ptr`
568 Type ptrType = LLVM::LLVMPointerType::get(context);
569
570 // grab a reference to the global module op:
571 auto moduleOp = op->getParentOfType<ModuleOp>();
572
573 auto mpiTraits = MPIImplTraits::get(moduleOp);
574 // get communicator
575 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
576
577 // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
578 auto rankFuncType =
579 LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
580 // get or create function declaration:
581 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
582 moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
583
584 // replace with function call
585 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
586 auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
587 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
588 ValueRange{comm, rankptr.getRes()});
589
590 // load the rank into a register
591 auto loadedRank =
592 LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
593
594 // if retval is checked, replace uses of retval with the results from the
595 // call op
596 SmallVector<Value> replacements;
597 if (op.getRetval())
598 replacements.push_back(callOp.getResult());
599
600 // replace all uses, then erase op
601 replacements.push_back(loadedRank.getRes());
602 rewriter.replaceOp(op, replacements);
603
604 return success();
605 }
606};
607
608//===----------------------------------------------------------------------===//
609// SendOpLowering
610//===----------------------------------------------------------------------===//
611
612struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
614
615 LogicalResult
616 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
617 ConversionPatternRewriter &rewriter) const override {
618 // get some helper vars
619 Location loc = op.getLoc();
620 MLIRContext *context = rewriter.getContext();
621 Type i32 = rewriter.getI32Type();
622 Type elemType = op.getRef().getType().getElementType();
623
624 // ptrType `!llvm.ptr`
625 Type ptrType = LLVM::LLVMPointerType::get(context);
626
627 // grab a reference to the global module op:
628 auto moduleOp = op->getParentOfType<ModuleOp>();
629
630 // get MPI_COMM_WORLD, dataType and pointer
631 auto [dataPtr, size] =
632 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
633 auto mpiTraits = MPIImplTraits::get(moduleOp);
634 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
635 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
636
637 // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
638 // tag, comm)`
639 auto funcType = LLVM::LLVMFunctionType::get(
640 i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
641 // get or create function declaration:
642 LLVM::LLVMFuncOp funcDecl =
643 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
644
645 // replace op with function call
646 auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
647 ValueRange{dataPtr, size, dataType,
648 adaptor.getDest(),
649 adaptor.getTag(), comm});
650 if (op.getRetval())
651 rewriter.replaceOp(op, funcCall.getResult());
652 else
653 rewriter.eraseOp(op);
654
655 return success();
656 }
657};
658
659//===----------------------------------------------------------------------===//
660// RecvOpLowering
661//===----------------------------------------------------------------------===//
662
663struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
665
666 LogicalResult
667 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
668 ConversionPatternRewriter &rewriter) const override {
669 // get some helper vars
670 Location loc = op.getLoc();
671 MLIRContext *context = rewriter.getContext();
672 Type i32 = rewriter.getI32Type();
673 Type i64 = rewriter.getI64Type();
674 Type elemType = op.getRef().getType().getElementType();
675
676 // ptrType `!llvm.ptr`
677 Type ptrType = LLVM::LLVMPointerType::get(context);
678
679 // grab a reference to the global module op:
680 auto moduleOp = op->getParentOfType<ModuleOp>();
681
682 // get MPI_COMM_WORLD, dataType, status_ignore and pointer
683 auto [dataPtr, size] =
684 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
685 auto mpiTraits = MPIImplTraits::get(moduleOp);
686 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
687 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
688 Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
689 mpiTraits->getStatusIgnore());
690 statusIgnore =
691 LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
692
693 // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
694 // tag, comm)`
695 auto funcType =
696 LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
697 i32, comm.getType(), ptrType});
698 // get or create function declaration:
699 LLVM::LLVMFuncOp funcDecl =
700 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
701
702 // replace op with function call
703 auto funcCall = LLVM::CallOp::create(
704 rewriter, loc, funcDecl,
705 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
706 adaptor.getTag(), comm, statusIgnore});
707 if (op.getRetval())
708 rewriter.replaceOp(op, funcCall.getResult());
709 else
710 rewriter.eraseOp(op);
711
712 return success();
713 }
714};
715
716//===----------------------------------------------------------------------===//
717// AllReduceOpLowering
718//===----------------------------------------------------------------------===//
719
720struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
722
723 LogicalResult
724 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
725 ConversionPatternRewriter &rewriter) const override {
726 Location loc = op.getLoc();
727 MLIRContext *context = rewriter.getContext();
728 Type i32 = rewriter.getI32Type();
729 Type i64 = rewriter.getI64Type();
730 Type elemType = op.getSendbuf().getType().getElementType();
731
732 // ptrType `!llvm.ptr`
733 Type ptrType = LLVM::LLVMPointerType::get(context);
734 auto moduleOp = op->getParentOfType<ModuleOp>();
735 auto mpiTraits = MPIImplTraits::get(moduleOp);
736 auto [sendPtr, sendSize] =
737 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
738 auto [recvPtr, recvSize] =
739 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
740
741 // If input and output are the same, request in-place operation.
742 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
743 sendPtr = LLVM::ConstantOp::create(
744 rewriter, loc, i64,
745 reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
746 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
747 }
748
749 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
750 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
751 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
752
753 // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
754 // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
755 auto funcType = LLVM::LLVMFunctionType::get(
756 i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
757 commWorld.getType()});
758 // get or create function declaration:
759 LLVM::LLVMFuncOp funcDecl =
760 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
761
762 // replace op with function call
763 auto funcCall = LLVM::CallOp::create(
764 rewriter, loc, funcDecl,
765 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
766
767 if (op.getRetval())
768 rewriter.replaceOp(op, funcCall.getResult());
769 else
770 rewriter.eraseOp(op);
771
772 return success();
773 }
774};
775
776//===----------------------------------------------------------------------===//
777// ConvertToLLVMPatternInterface implementation
778//===----------------------------------------------------------------------===//
779
780/// Implement the interface to convert MPI to LLVM.
/// NOTE(review): the struct name still says "Func", but this interface is the
/// one attached to mpi::MPIDialect when the conversion interface is
/// registered.
781struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
783 /// Hook for derived dialect interface to provide conversion patterns
784 /// and mark dialect legal for the conversion target.
785 void populateConvertToLLVMConversionPatterns(
786 ConversionTarget &target, LLVMTypeConverter &typeConverter,
787 RewritePatternSet &patterns) const final {
 // NOTE(review): as shown here the body adds no patterns and leaves `target`
 // untouched; presumably a call populating the MPI-to-LLVM patterns belongs
 // here — confirm against the original file.
789 }
790};
791} // namespace
792
793//===----------------------------------------------------------------------===//
794// Pattern Population
795//===----------------------------------------------------------------------===//
796
799 // Using i64 as a portable, intermediate type for !mpi.comm.
800 // It would be nicer to somehow get the right type directly, but DLTI is not
801 // available here.
802 converter.addConversion([](mpi::CommType type) {
803 return IntegerType::get(type.getContext(), 64);
804 });
805 patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
806 FinalizeOpLowering, InitOpLowering, SendOpLowering,
807 RecvOpLowering, AllReduceOpLowering>(converter);
808}
809
811 registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
812 dialect->addInterfaces<FuncToLLVMDialectInterface>();
813 });
814}
return success()
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:213
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:258
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
Definition DLTI.cpp:537
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...