MLIR 23.0.0git
MPIToLLVM.cpp
Go to the documentation of this file.
1//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9//
10// Copyright (C) by Argonne National Laboratory
11// See COPYRIGHT in top-level directory
12// of MPICH source repository.
13//
14
#include "llvm/Support/ErrorHandling.h"
#include <memory>
26
27using namespace mlir;
28
29namespace {
30
31template <typename Op, typename... Args>
32static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
33 ConversionPatternRewriter &rewriter, StringRef name,
34 Args &&...args) {
35 Op ret;
36 if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
37 ConversionPatternRewriter::InsertionGuard guard(rewriter);
38 rewriter.setInsertionPointToStart(moduleOp.getBody());
39 ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
40 }
41 return ret;
42}
43
44static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
45 const Location loc,
46 ConversionPatternRewriter &rewriter,
47 StringRef name,
48 LLVM::LLVMFunctionType type) {
49 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
50 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
51}
52
53std::pair<Value, Value> getRawPtrAndSize(const Location loc,
54 ConversionPatternRewriter &rewriter,
55 Value memRef, int64_t rank,
56 Type elType) {
57 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
58 Value dataPtr =
59 LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
60 Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
61 rewriter.getI64Type(), memRef, 2);
62 Value resPtr =
63 LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
64 Value size = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
65 rewriter.getIndexAttr(1));
66 if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
67 for (int64_t i = 0; i < rank; ++i) {
68 Value dim = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
69 ArrayRef<int64_t>{3, i});
70 dim = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), dim);
71 size =
72 LLVM::MulOp::create(rewriter, loc, rewriter.getI32Type(), dim, size);
73 }
74 } else {
75 size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
76 }
77 return {resPtr, size};
78}
79
80/// When lowering the mpi dialect to functions calls certain details
81/// differ between various MPI implementations. This class will provide
82/// these in a generic way, depending on the MPI implementation that got
83/// selected by the DLTI attribute on the module.
84class MPIImplTraits {
85 ModuleOp &moduleOp;
86
87public:
88 /// Instantiate a new MPIImplTraits object according to the DLTI attribute
89 /// on the given module. Default to MPICH if no attribute is present or
90 /// the value is unknown.
91 static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
92
93 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
94
95 virtual ~MPIImplTraits() = default;
96
97 ModuleOp &getModuleOp() { return moduleOp; }
98
99 /// Gets or creates MPI_COMM_WORLD as a Value.
100 /// Different MPI implementations have different communicator types.
101 /// Using i64 as a portable, intermediate type.
102 /// Appropriate cast needs to take place before calling MPI functions.
103 virtual Value getCommWorld(Location loc,
104 ConversionPatternRewriter &rewriter) = 0;
105
106 /// Type converter provides i64 type for communicator type.
107 /// Converts to native type, which might be ptr or int or whatever.
108 virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter,
109 Value comm) = 0;
110
111 /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
112 virtual intptr_t getStatusIgnore() = 0;
113
114 /// Get the MPI_IN_PLACE value (void *).
115 virtual void *getInPlace() = 0;
116
117 /// Gets or creates an MPI datatype as a value which corresponds to the given
118 /// type.
119 virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter,
120 Type type) = 0;
121
122 /// Gets or creates an MPI_Op value which corresponds to the given
123 /// enum value.
124 virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter,
125 mpi::MPI_ReductionOpEnum opAttr) = 0;
126};
127
128//===----------------------------------------------------------------------===//
129// Implementation details for MPICH ABI compatible MPI implementations
130//===----------------------------------------------------------------------===//
131
132class MPICHImplTraits : public MPIImplTraits {
133 static constexpr int MPI_FLOAT = 0x4c00040a;
134 static constexpr int MPI_DOUBLE = 0x4c00080b;
135 static constexpr int MPI_INT8_T = 0x4c000137;
136 static constexpr int MPI_INT16_T = 0x4c000238;
137 static constexpr int MPI_INT32_T = 0x4c000439;
138 static constexpr int MPI_INT64_T = 0x4c00083a;
139 static constexpr int MPI_UINT8_T = 0x4c00013b;
140 static constexpr int MPI_UINT16_T = 0x4c00023c;
141 static constexpr int MPI_UINT32_T = 0x4c00043d;
142 static constexpr int MPI_UINT64_T = 0x4c00083e;
143 static constexpr int MPI_MAX = 0x58000001;
144 static constexpr int MPI_MIN = 0x58000002;
145 static constexpr int MPI_SUM = 0x58000003;
146 static constexpr int MPI_PROD = 0x58000004;
147 static constexpr int MPI_LAND = 0x58000005;
148 static constexpr int MPI_BAND = 0x58000006;
149 static constexpr int MPI_LOR = 0x58000007;
150 static constexpr int MPI_BOR = 0x58000008;
151 static constexpr int MPI_LXOR = 0x58000009;
152 static constexpr int MPI_BXOR = 0x5800000a;
153 static constexpr int MPI_MINLOC = 0x5800000b;
154 static constexpr int MPI_MAXLOC = 0x5800000c;
155 static constexpr int MPI_REPLACE = 0x5800000d;
156 static constexpr int MPI_NO_OP = 0x5800000e;
157
158public:
159 using MPIImplTraits::MPIImplTraits;
160
161 ~MPICHImplTraits() override = default;
162
163 Value getCommWorld(const Location loc,
164 ConversionPatternRewriter &rewriter) override {
165 static constexpr int MPI_COMM_WORLD = 0x44000000;
166 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
167 MPI_COMM_WORLD);
168 }
169
170 Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
171 Value comm) override {
172 return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
173 }
174
175 intptr_t getStatusIgnore() override { return 1; }
176
177 void *getInPlace() override { return reinterpret_cast<void *>(-1); }
178
179 Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
180 Type type) override {
181 int32_t mtype = 0;
182 if (type.isF32())
183 mtype = MPI_FLOAT;
184 else if (type.isF64())
185 mtype = MPI_DOUBLE;
186 else if (type.isInteger(64) && !type.isUnsignedInteger())
187 mtype = MPI_INT64_T;
188 else if (type.isInteger(64))
189 mtype = MPI_UINT64_T;
190 else if (type.isInteger(32) && !type.isUnsignedInteger())
191 mtype = MPI_INT32_T;
192 else if (type.isInteger(32))
193 mtype = MPI_UINT32_T;
194 else if (type.isInteger(16) && !type.isUnsignedInteger())
195 mtype = MPI_INT16_T;
196 else if (type.isInteger(16))
197 mtype = MPI_UINT16_T;
198 else if (type.isInteger(8) && !type.isUnsignedInteger())
199 mtype = MPI_INT8_T;
200 else if (type.isInteger(8))
201 mtype = MPI_UINT8_T;
202 else
203 assert(false && "unsupported type");
204 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
205 mtype);
206 }
207
208 Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
209 mpi::MPI_ReductionOpEnum opAttr) override {
210 int32_t op = MPI_NO_OP;
211 switch (opAttr) {
212 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
213 op = MPI_NO_OP;
214 break;
215 case mpi::MPI_ReductionOpEnum::MPI_MAX:
216 op = MPI_MAX;
217 break;
218 case mpi::MPI_ReductionOpEnum::MPI_MIN:
219 op = MPI_MIN;
220 break;
221 case mpi::MPI_ReductionOpEnum::MPI_SUM:
222 op = MPI_SUM;
223 break;
224 case mpi::MPI_ReductionOpEnum::MPI_PROD:
225 op = MPI_PROD;
226 break;
227 case mpi::MPI_ReductionOpEnum::MPI_LAND:
228 op = MPI_LAND;
229 break;
230 case mpi::MPI_ReductionOpEnum::MPI_BAND:
231 op = MPI_BAND;
232 break;
233 case mpi::MPI_ReductionOpEnum::MPI_LOR:
234 op = MPI_LOR;
235 break;
236 case mpi::MPI_ReductionOpEnum::MPI_BOR:
237 op = MPI_BOR;
238 break;
239 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
240 op = MPI_LXOR;
241 break;
242 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
243 op = MPI_BXOR;
244 break;
245 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
246 op = MPI_MINLOC;
247 break;
248 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
249 op = MPI_MAXLOC;
250 break;
251 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
252 op = MPI_REPLACE;
253 break;
254 }
255 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
256 }
257};
258
259//===----------------------------------------------------------------------===//
260// Implementation details for OpenMPI
261//===----------------------------------------------------------------------===//
262class OMPIImplTraits : public MPIImplTraits {
263 LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
264 ConversionPatternRewriter &rewriter,
265 StringRef name,
266 LLVM::LLVMStructType type) {
267
268 return getOrDefineGlobal<LLVM::GlobalOp>(
269 getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
270 LLVM::Linkage::External, name,
271 /*value=*/Attribute(), /*alignment=*/0, 0);
272 }
273
274public:
275 using MPIImplTraits::MPIImplTraits;
276
277 ~OMPIImplTraits() override = default;
278
279 Value getCommWorld(const Location loc,
280 ConversionPatternRewriter &rewriter) override {
281 auto *context = rewriter.getContext();
282 // get external opaque struct pointer type
283 auto commStructT =
284 LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
285 StringRef name = "ompi_mpi_comm_world";
286
287 // make sure global op definition exists
288 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
289
290 // get address of symbol
291 auto comm = LLVM::AddressOfOp::create(rewriter, loc,
292 LLVM::LLVMPointerType::get(context),
293 SymbolRefAttr::get(context, name));
294 return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
295 }
296
297 Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
298 Value comm) override {
299 return LLVM::IntToPtrOp::create(
300 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
301 }
302
303 intptr_t getStatusIgnore() override { return 0; }
304
305 void *getInPlace() override { return reinterpret_cast<void *>(1); }
306
307 Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
308 Type type) override {
309 StringRef mtype;
310 if (type.isF32())
311 mtype = "ompi_mpi_float";
312 else if (type.isF64())
313 mtype = "ompi_mpi_double";
314 else if (type.isInteger(64) && !type.isUnsignedInteger())
315 mtype = "ompi_mpi_int64_t";
316 else if (type.isInteger(64))
317 mtype = "ompi_mpi_uint64_t";
318 else if (type.isInteger(32) && !type.isUnsignedInteger())
319 mtype = "ompi_mpi_int32_t";
320 else if (type.isInteger(32))
321 mtype = "ompi_mpi_uint32_t";
322 else if (type.isInteger(16) && !type.isUnsignedInteger())
323 mtype = "ompi_mpi_int16_t";
324 else if (type.isInteger(16))
325 mtype = "ompi_mpi_uint16_t";
326 else if (type.isInteger(8) && !type.isUnsignedInteger())
327 mtype = "ompi_mpi_int8_t";
328 else if (type.isInteger(8))
329 mtype = "ompi_mpi_uint8_t";
330 else
331 assert(false && "unsupported type");
332
333 auto *context = rewriter.getContext();
334 // get external opaque struct pointer type
335 auto typeStructT =
336 LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
337 // make sure global op definition exists
338 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
339 // get address of symbol
340 return LLVM::AddressOfOp::create(rewriter, loc,
341 LLVM::LLVMPointerType::get(context),
342 SymbolRefAttr::get(context, mtype));
343 }
344
345 Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
346 mpi::MPI_ReductionOpEnum opAttr) override {
347 StringRef op;
348 switch (opAttr) {
349 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
350 op = "ompi_mpi_no_op";
351 break;
352 case mpi::MPI_ReductionOpEnum::MPI_MAX:
353 op = "ompi_mpi_max";
354 break;
355 case mpi::MPI_ReductionOpEnum::MPI_MIN:
356 op = "ompi_mpi_min";
357 break;
358 case mpi::MPI_ReductionOpEnum::MPI_SUM:
359 op = "ompi_mpi_sum";
360 break;
361 case mpi::MPI_ReductionOpEnum::MPI_PROD:
362 op = "ompi_mpi_prod";
363 break;
364 case mpi::MPI_ReductionOpEnum::MPI_LAND:
365 op = "ompi_mpi_land";
366 break;
367 case mpi::MPI_ReductionOpEnum::MPI_BAND:
368 op = "ompi_mpi_band";
369 break;
370 case mpi::MPI_ReductionOpEnum::MPI_LOR:
371 op = "ompi_mpi_lor";
372 break;
373 case mpi::MPI_ReductionOpEnum::MPI_BOR:
374 op = "ompi_mpi_bor";
375 break;
376 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
377 op = "ompi_mpi_lxor";
378 break;
379 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
380 op = "ompi_mpi_bxor";
381 break;
382 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
383 op = "ompi_mpi_minloc";
384 break;
385 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
386 op = "ompi_mpi_maxloc";
387 break;
388 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
389 op = "ompi_mpi_replace";
390 break;
391 }
392 auto *context = rewriter.getContext();
393 // get external opaque struct pointer type
394 auto opStructT =
395 LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
396 // make sure global op definition exists
397 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
398 // get address of symbol
399 return LLVM::AddressOfOp::create(rewriter, loc,
400 LLVM::LLVMPointerType::get(context),
401 SymbolRefAttr::get(context, op));
402 }
403};
404
405std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
406 auto attr = dlti::query(moduleOp, {"MPI:Implementation"}, false);
407 if (failed(attr))
408 return std::make_unique<MPICHImplTraits>(moduleOp);
409 auto strAttr = dyn_cast<StringAttr>(attr.value());
410 if (strAttr && strAttr.getValue() == "OpenMPI")
411 return std::make_unique<OMPIImplTraits>(moduleOp);
412 if (!strAttr || strAttr.getValue() != "MPICH")
413 moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
414 << (strAttr ? strAttr.getValue() : "<NULL>")
415 << "), defaulting to MPICH";
416 return std::make_unique<MPICHImplTraits>(moduleOp);
417}
418
419//===----------------------------------------------------------------------===//
420// InitOpLowering
421//===----------------------------------------------------------------------===//
422
423struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
425
426 LogicalResult
427 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter) const override {
429 Location loc = op.getLoc();
430
431 // ptrType `!llvm.ptr`
432 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
433
434 // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
435 auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
436 Value llvmnull = nullPtrOp.getRes();
437
438 // grab a reference to the global module op:
439 auto moduleOp = op->getParentOfType<ModuleOp>();
440
441 // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
442 auto initFuncType =
443 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
444 // get or create function declaration:
445 LLVM::LLVMFuncOp initDecl =
446 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
447
448 // replace init with function call
449 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
450 ValueRange{llvmnull, llvmnull});
451
452 return success();
453 }
454};
455
456//===----------------------------------------------------------------------===//
457// FinalizeOpLowering
458//===----------------------------------------------------------------------===//
459
460struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
462
463 LogicalResult
464 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter) const override {
466 // get loc
467 Location loc = op.getLoc();
468
469 // grab a reference to the global module op:
470 auto moduleOp = op->getParentOfType<ModuleOp>();
471
472 // LLVM Function type representing `i32 MPI_Finalize()`
473 auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
474 // get or create function declaration:
475 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
476 moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
477
478 // replace init with function call
479 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
480
481 return success();
482 }
483};
484
485//===----------------------------------------------------------------------===//
486// CommWorldOpLowering
487//===----------------------------------------------------------------------===//
488
489struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
491
492 LogicalResult
493 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
494 ConversionPatternRewriter &rewriter) const override {
495 // grab a reference to the global module op:
496 auto moduleOp = op->getParentOfType<ModuleOp>();
497 auto mpiTraits = MPIImplTraits::get(moduleOp);
498 // get MPI_COMM_WORLD
499 rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
500
501 return success();
502 }
503};
504
505//===----------------------------------------------------------------------===//
506// CommSplitOpLowering
507//===----------------------------------------------------------------------===//
508
509struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
511
512 LogicalResult
513 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter) const override {
515 // grab a reference to the global module op:
516 auto moduleOp = op->getParentOfType<ModuleOp>();
517 auto mpiTraits = MPIImplTraits::get(moduleOp);
518 Type i32 = rewriter.getI32Type();
519 Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
520 Location loc = op.getLoc();
521
522 // get communicator
523 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
524 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
525 auto outPtr =
526 LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one);
527
528 // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
529 auto funcType =
530 LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
531 // get or create function declaration:
532 LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
533 "MPI_Comm_split", funcType);
534
535 auto callOp =
536 LLVM::CallOp::create(rewriter, loc, funcDecl,
537 ValueRange{comm, adaptor.getColor(),
538 adaptor.getKey(), outPtr.getRes()});
539
540 // load the communicator into a register
541 Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
542 res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
543
544 // if retval is checked, replace uses of retval with the results from the
545 // call op
546 SmallVector<Value> replacements;
547 if (op.getRetval())
548 replacements.push_back(callOp.getResult());
549
550 // replace op
551 replacements.push_back(res);
552 rewriter.replaceOp(op, replacements);
553
554 return success();
555 }
556};
557
558//===----------------------------------------------------------------------===//
559// CommRankOpLowering
560//===----------------------------------------------------------------------===//
561
562struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
564
565 LogicalResult
566 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
567 ConversionPatternRewriter &rewriter) const override {
568 // get some helper vars
569 Location loc = op.getLoc();
570 MLIRContext *context = rewriter.getContext();
571 Type i32 = rewriter.getI32Type();
572
573 // ptrType `!llvm.ptr`
574 Type ptrType = LLVM::LLVMPointerType::get(context);
575
576 // grab a reference to the global module op:
577 auto moduleOp = op->getParentOfType<ModuleOp>();
578
579 auto mpiTraits = MPIImplTraits::get(moduleOp);
580 // get communicator
581 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
582
583 // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
584 auto rankFuncType =
585 LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
586 // get or create function declaration:
587 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
588 moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
589
590 // replace with function call
591 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
592 auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
593 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
594 ValueRange{comm, rankptr.getRes()});
595
596 // load the rank into a register
597 auto loadedRank =
598 LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
599
600 // if retval is checked, replace uses of retval with the results from the
601 // call op
602 SmallVector<Value> replacements;
603 if (op.getRetval())
604 replacements.push_back(callOp.getResult());
605
606 // replace all uses, then erase op
607 replacements.push_back(loadedRank.getRes());
608 rewriter.replaceOp(op, replacements);
609
610 return success();
611 }
612};
613
614//===----------------------------------------------------------------------===//
615// CommSizeOpLowering
616//===----------------------------------------------------------------------===//
617
618static Value createOrFoldCommSize(ConversionPatternRewriter &rewriter,
619 Location loc, Value commOrg,
620 Value commAdapt) {
621 auto i32 = rewriter.getI32Type();
622 auto nRanksOp = mpi::CommSizeOp::create(rewriter, loc, i32, commOrg);
623 if (succeeded(FoldToDLTIConst(nRanksOp, "MPI:comm_world_size", rewriter)))
624 return nRanksOp.getSize();
625 rewriter.eraseOp(nRanksOp);
626 return mpi::CommSizeOp::create(rewriter, loc, i32, commAdapt).getSize();
627}
628
629struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
631
632 LogicalResult
633 matchAndRewrite(mpi::CommSizeOp op, OpAdaptor adaptor,
634 ConversionPatternRewriter &rewriter) const override {
635 // get some helper vars
636 Location loc = op.getLoc();
637 MLIRContext *context = rewriter.getContext();
638 Type i32 = rewriter.getI32Type();
639
640 // ptrType `!llvm.ptr`
641 Type ptrType = LLVM::LLVMPointerType::get(context);
642
643 // grab a reference to the global module op:
644 auto moduleOp = op->getParentOfType<ModuleOp>();
645
646 auto mpiTraits = MPIImplTraits::get(moduleOp);
647 // get communicator
648 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
649
650 // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
651 auto SizeFuncType =
652 LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
653 // get or create function declaration:
654 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
655 moduleOp, loc, rewriter, "MPI_Comm_size", SizeFuncType);
656
657 // replace with function call
658 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
659 auto sizeptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
660 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
661 ValueRange{comm, sizeptr.getRes()});
662
663 // load the Size into a register
664 auto loadedSize =
665 LLVM::LoadOp::create(rewriter, loc, i32, sizeptr.getResult());
666
667 // if retval is checked, replace uses of retval with the results from the
668 // call op
669 SmallVector<Value> replacements;
670 if (op.getRetval())
671 replacements.push_back(callOp.getResult());
672
673 // replace all uses, then erase op
674 replacements.push_back(loadedSize.getRes());
675 rewriter.replaceOp(op, replacements);
676
677 return success();
678 }
679};
680
681//===----------------------------------------------------------------------===//
682// SendOpLowering
683//===----------------------------------------------------------------------===//
684
685struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
687
688 LogicalResult
689 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
690 ConversionPatternRewriter &rewriter) const override {
691 // get some helper vars
692 Location loc = op.getLoc();
693 MLIRContext *context = rewriter.getContext();
694 Type i32 = rewriter.getI32Type();
695 Type elemType = op.getRef().getType().getElementType();
696 int64_t rank = op.getRef().getType().getRank();
697
698 // ptrType `!llvm.ptr`
699 Type ptrType = LLVM::LLVMPointerType::get(context);
700
701 // grab a reference to the global module op:
702 auto moduleOp = op->getParentOfType<ModuleOp>();
703
704 // get MPI_COMM_WORLD, dataType and pointer
705 auto [dataPtr, size] =
706 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
707 auto mpiTraits = MPIImplTraits::get(moduleOp);
708 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
709 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
710
711 // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
712 // tag, comm)`
713 auto funcType = LLVM::LLVMFunctionType::get(
714 i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
715 // get or create function declaration:
716 LLVM::LLVMFuncOp funcDecl =
717 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
718
719 // replace op with function call
720 auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
721 ValueRange{dataPtr, size, dataType,
722 adaptor.getDest(),
723 adaptor.getTag(), comm});
724 if (op.getRetval())
725 rewriter.replaceOp(op, funcCall.getResult());
726 else
727 rewriter.eraseOp(op);
728
729 return success();
730 }
731};
732
733//===----------------------------------------------------------------------===//
734// RecvOpLowering
735//===----------------------------------------------------------------------===//
736
737struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
739
740 LogicalResult
741 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
742 ConversionPatternRewriter &rewriter) const override {
743 // get some helper vars
744 Location loc = op.getLoc();
745 MLIRContext *context = rewriter.getContext();
746 Type i32 = rewriter.getI32Type();
747 Type i64 = rewriter.getI64Type();
748 Type elemType = op.getRef().getType().getElementType();
749 int64_t rank = op.getRef().getType().getRank();
750
751 // ptrType `!llvm.ptr`
752 Type ptrType = LLVM::LLVMPointerType::get(context);
753
754 // grab a reference to the global module op:
755 auto moduleOp = op->getParentOfType<ModuleOp>();
756
757 // get MPI_COMM_WORLD, dataType, status_ignore and pointer
758 auto [dataPtr, size] =
759 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
760 auto mpiTraits = MPIImplTraits::get(moduleOp);
761 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
762 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
763 Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
764 mpiTraits->getStatusIgnore());
765 statusIgnore =
766 LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
767
768 // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
769 // tag, comm)`
770 auto funcType =
771 LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
772 i32, comm.getType(), ptrType});
773 // get or create function declaration:
774 LLVM::LLVMFuncOp funcDecl =
775 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
776
777 // replace op with function call
778 auto funcCall = LLVM::CallOp::create(
779 rewriter, loc, funcDecl,
780 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
781 adaptor.getTag(), comm, statusIgnore});
782 if (op.getRetval())
783 rewriter.replaceOp(op, funcCall.getResult());
784 else
785 rewriter.eraseOp(op);
786
787 return success();
788 }
789};
790
791//===----------------------------------------------------------------------===//
792// AllGatherOpLowering
793//===----------------------------------------------------------------------===//
794
795struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
797
798 LogicalResult
799 matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
800 ConversionPatternRewriter &rewriter) const override {
801 Location loc = op.getLoc();
802 MLIRContext *context = rewriter.getContext();
803 Type sElemType = op.getSendbuf().getType().getElementType();
804 Type rElemType = op.getRecvbuf().getType().getElementType();
805 int64_t sRank = op.getSendbuf().getType().getRank();
806 int64_t rRank = op.getRecvbuf().getType().getRank();
807 auto [sendPtr, sendSize] =
808 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, sElemType);
809 auto [recvPtr, recvSize] =
810 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, rElemType);
811
812 auto moduleOp = op->getParentOfType<ModuleOp>();
813 auto mpiTraits = MPIImplTraits::get(moduleOp);
814 Value sDataType = mpiTraits->getDataType(loc, rewriter, sElemType);
815 Value rDataType = mpiTraits->getDataType(loc, rewriter, rElemType);
816 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
817
818 Type ptrType = LLVM::LLVMPointerType::get(context);
819 Type i32 = rewriter.getI32Type();
820 // int MPI_Allgather(
821 // const void* buffer_send, int count_send, MPI_Datatype datatype_send,
822 // void* buffer_recv, int count_recv, MPI_Datatype datatype_recv,
823 // MPI_Comm communicator);
824 auto funcType = LLVM::LLVMFunctionType::get(
825 i32, {ptrType, i32, sDataType.getType(), ptrType, i32,
826 rDataType.getType(), comm.getType()});
827 // get or create function declaration:
828 LLVM::LLVMFuncOp funcDecl =
829 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
830
831 // count_recv is the number of elements received from each rank, not total
832 Value nRanks =
833 createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
834 Value recvCountPerRank =
835 LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
836
837 // replace op with function call
838 auto funcCall =
839 LLVM::CallOp::create(rewriter, loc, funcDecl,
840 ValueRange{sendPtr, sendSize, sDataType, recvPtr,
841 recvCountPerRank, rDataType, comm});
842
843 if (op.getRetval())
844 rewriter.replaceOp(op, funcCall.getResult());
845 else
846 rewriter.eraseOp(op);
847
848 return success();
849 }
850};
851
852//===----------------------------------------------------------------------===//
853// AllReduceOpLowering
854//===----------------------------------------------------------------------===//
855
856struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
858
859 LogicalResult
860 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
861 ConversionPatternRewriter &rewriter) const override {
862 Location loc = op.getLoc();
863 MLIRContext *context = rewriter.getContext();
864 Type i32 = rewriter.getI32Type();
865 Type i64 = rewriter.getI64Type();
866 Type elemType = op.getSendbuf().getType().getElementType();
867 int64_t sRank = op.getSendbuf().getType().getRank();
868 int64_t rRank = op.getRecvbuf().getType().getRank();
869
870 // ptrType `!llvm.ptr`
871 Type ptrType = LLVM::LLVMPointerType::get(context);
872 auto moduleOp = op->getParentOfType<ModuleOp>();
873 auto mpiTraits = MPIImplTraits::get(moduleOp);
874 auto [sendPtr, sendSize] =
875 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
876 auto [recvPtr, recvSize] =
877 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
878
879 // If input and output are the same, request in-place operation.
880 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
881 sendPtr = LLVM::ConstantOp::create(
882 rewriter, loc, i64,
883 reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
884 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
885 }
886
887 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
888 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
889 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
890
891 // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
892 // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
893 auto funcType = LLVM::LLVMFunctionType::get(
894 i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
895 commWorld.getType()});
896 // get or create function declaration:
897 LLVM::LLVMFuncOp funcDecl =
898 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
899
900 // replace op with function call
901 auto funcCall = LLVM::CallOp::create(
902 rewriter, loc, funcDecl,
903 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
904
905 if (op.getRetval())
906 rewriter.replaceOp(op, funcCall.getResult());
907 else
908 rewriter.eraseOp(op);
909
910 return success();
911 }
912};
913
914//===----------------------------------------------------------------------===//
915// ConvertToLLVMPatternInterface implementation
916//===----------------------------------------------------------------------===//
917
918/// Implement the interface to convert Func to LLVM.
919struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
921 /// Hook for derived dialect interface to provide conversion patterns
922 /// and mark dialect legal for the conversion target.
923 void populateConvertToLLVMConversionPatterns(
924 ConversionTarget &target, LLVMTypeConverter &typeConverter,
925 RewritePatternSet &patterns) const final {
927 }
928};
929} // namespace
930
931//===----------------------------------------------------------------------===//
932// Pattern Population
933//===----------------------------------------------------------------------===//
934
937 // Using i64 as a portable, intermediate type for !mpi.comm.
938 // It would be nicer to somehow get the right type directly, but TLDI is not
939 // available here.
940 converter.addConversion([](mpi::CommType type) {
941 return IntegerType::get(type.getContext(), 64);
942 });
943 patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
944 CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
945 SendOpLowering, RecvOpLowering, AllGatherOpLowering,
946 AllReduceOpLowering>(converter);
947}
948
950 registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
951 dialect->addInterfaces<FuncToLLVMDialectInterface>();
952 });
953}
return success()
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:222
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
Definition DLTI.cpp:537
LogicalResult FoldToDLTIConst(OpT op, const char *key, mlir::PatternRewriter &b)
Definition Utils.h:19
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...