MLIR 23.0.0git
MPIToLLVM.cpp
Go to the documentation of this file.
1//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9//
10// Copyright (C) by Argonne National Laboratory
11// See COPYRIGHT in top-level directory
12// of MPICH source repository.
13//
14
26#include <memory>
27
28using namespace mlir;
29
30namespace {
31
32template <typename Op, typename... Args>
33static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
34 ConversionPatternRewriter &rewriter, StringRef name,
35 Args &&...args) {
36 Op ret;
37 if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
38 ConversionPatternRewriter::InsertionGuard guard(rewriter);
39 rewriter.setInsertionPointToStart(moduleOp.getBody());
40 ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
41 }
42 return ret;
43}
44
45static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
46 const Location loc,
47 ConversionPatternRewriter &rewriter,
48 StringRef name,
49 LLVM::LLVMFunctionType type) {
50 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
51 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
52}
53
54std::pair<Value, Value> getRawPtrAndSize(const Location loc,
55 ConversionPatternRewriter &rewriter,
56 Value memRef, int64_t rank,
57 Type elType) {
58 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
59 Value dataPtr =
60 LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
61 Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
62 rewriter.getI64Type(), memRef, 2);
63 Value resPtr =
64 LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
65 Value size = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
66 rewriter.getIndexAttr(1));
67 if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
68 for (int64_t i = 0; i < rank; ++i) {
69 Value dim = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
70 ArrayRef<int64_t>{3, i});
71 dim = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), dim);
72 size =
73 LLVM::MulOp::create(rewriter, loc, rewriter.getI32Type(), dim, size);
74 }
75 } else {
76 size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
77 }
78 return {resPtr, size};
79}
80
81/// When lowering the mpi dialect to functions calls certain details
82/// differ between various MPI implementations. This class will provide
83/// these in a generic way, depending on the MPI implementation that got
84/// selected by the DLTI attribute on the module.
85class MPIImplTraits {
86 ModuleOp &moduleOp;
87
88public:
89 /// Instantiate a new MPIImplTraits object according to the DLTI attribute
90 /// on the given module. Default to MPICH if no attribute is present or
91 /// the value is unknown.
92 static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
93
94 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
95
96 virtual ~MPIImplTraits() = default;
97
98 ModuleOp &getModuleOp() { return moduleOp; }
99
100 /// Gets or creates MPI_COMM_WORLD as a Value.
101 /// Different MPI implementations have different communicator types.
102 /// Using i64 as a portable, intermediate type.
103 /// Appropriate cast needs to take place before calling MPI functions.
104 virtual Value getCommWorld(Location loc,
105 ConversionPatternRewriter &rewriter) = 0;
106
107 /// Type converter provides i64 type for communicator type.
108 /// Converts to native type, which might be ptr or int or whatever.
109 virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter,
110 Value comm) = 0;
111
112 /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
113 virtual intptr_t getStatusIgnore() = 0;
114
115 /// Get the MPI_IN_PLACE value (void *).
116 virtual void *getInPlace() = 0;
117
118 /// Gets or creates an MPI datatype as a value which corresponds to the given
119 /// type.
120 virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter,
121 Type type) = 0;
122
123 /// Gets or creates an MPI_Op value which corresponds to the given
124 /// enum value.
125 virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter,
126 mpi::MPI_ReductionOpEnum opAttr) = 0;
127};
128
129//===----------------------------------------------------------------------===//
130// Implementation details for MPICH ABI compatible MPI implementations
131//===----------------------------------------------------------------------===//
132
133class MPICHImplTraits : public MPIImplTraits {
134 static constexpr int MPI_FLOAT = 0x4c00040a;
135 static constexpr int MPI_DOUBLE = 0x4c00080b;
136 static constexpr int MPI_INT8_T = 0x4c000137;
137 static constexpr int MPI_INT16_T = 0x4c000238;
138 static constexpr int MPI_INT32_T = 0x4c000439;
139 static constexpr int MPI_INT64_T = 0x4c00083a;
140 static constexpr int MPI_UINT8_T = 0x4c00013b;
141 static constexpr int MPI_UINT16_T = 0x4c00023c;
142 static constexpr int MPI_UINT32_T = 0x4c00043d;
143 static constexpr int MPI_UINT64_T = 0x4c00083e;
144 static constexpr int MPI_MAX = 0x58000001;
145 static constexpr int MPI_MIN = 0x58000002;
146 static constexpr int MPI_SUM = 0x58000003;
147 static constexpr int MPI_PROD = 0x58000004;
148 static constexpr int MPI_LAND = 0x58000005;
149 static constexpr int MPI_BAND = 0x58000006;
150 static constexpr int MPI_LOR = 0x58000007;
151 static constexpr int MPI_BOR = 0x58000008;
152 static constexpr int MPI_LXOR = 0x58000009;
153 static constexpr int MPI_BXOR = 0x5800000a;
154 static constexpr int MPI_MINLOC = 0x5800000b;
155 static constexpr int MPI_MAXLOC = 0x5800000c;
156 static constexpr int MPI_REPLACE = 0x5800000d;
157 static constexpr int MPI_NO_OP = 0x5800000e;
158
159public:
160 using MPIImplTraits::MPIImplTraits;
161
162 ~MPICHImplTraits() override = default;
163
164 Value getCommWorld(const Location loc,
165 ConversionPatternRewriter &rewriter) override {
166 static constexpr int MPI_COMM_WORLD = 0x44000000;
167 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
168 MPI_COMM_WORLD);
169 }
170
171 Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
172 Value comm) override {
173 return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
174 }
175
176 intptr_t getStatusIgnore() override { return 1; }
177
178 void *getInPlace() override { return reinterpret_cast<void *>(-1); }
179
180 Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
181 Type type) override {
182 int32_t mtype = 0;
183 if (type.isF32())
184 mtype = MPI_FLOAT;
185 else if (type.isF64())
186 mtype = MPI_DOUBLE;
187 else if (type.isInteger(64) && !type.isUnsignedInteger())
188 mtype = MPI_INT64_T;
189 else if (type.isInteger(64))
190 mtype = MPI_UINT64_T;
191 else if (type.isInteger(32) && !type.isUnsignedInteger())
192 mtype = MPI_INT32_T;
193 else if (type.isInteger(32))
194 mtype = MPI_UINT32_T;
195 else if (type.isInteger(16) && !type.isUnsignedInteger())
196 mtype = MPI_INT16_T;
197 else if (type.isInteger(16))
198 mtype = MPI_UINT16_T;
199 else if (type.isInteger(8) && !type.isUnsignedInteger())
200 mtype = MPI_INT8_T;
201 else if (type.isInteger(8))
202 mtype = MPI_UINT8_T;
203 else
204 assert(false && "unsupported type");
205 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
206 mtype);
207 }
208
209 Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
210 mpi::MPI_ReductionOpEnum opAttr) override {
211 int32_t op = MPI_NO_OP;
212 switch (opAttr) {
213 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
214 op = MPI_NO_OP;
215 break;
216 case mpi::MPI_ReductionOpEnum::MPI_MAX:
217 op = MPI_MAX;
218 break;
219 case mpi::MPI_ReductionOpEnum::MPI_MIN:
220 op = MPI_MIN;
221 break;
222 case mpi::MPI_ReductionOpEnum::MPI_SUM:
223 op = MPI_SUM;
224 break;
225 case mpi::MPI_ReductionOpEnum::MPI_PROD:
226 op = MPI_PROD;
227 break;
228 case mpi::MPI_ReductionOpEnum::MPI_LAND:
229 op = MPI_LAND;
230 break;
231 case mpi::MPI_ReductionOpEnum::MPI_BAND:
232 op = MPI_BAND;
233 break;
234 case mpi::MPI_ReductionOpEnum::MPI_LOR:
235 op = MPI_LOR;
236 break;
237 case mpi::MPI_ReductionOpEnum::MPI_BOR:
238 op = MPI_BOR;
239 break;
240 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
241 op = MPI_LXOR;
242 break;
243 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
244 op = MPI_BXOR;
245 break;
246 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
247 op = MPI_MINLOC;
248 break;
249 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
250 op = MPI_MAXLOC;
251 break;
252 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
253 op = MPI_REPLACE;
254 break;
255 }
256 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
257 }
258};
259
260//===----------------------------------------------------------------------===//
261// Implementation details for OpenMPI
262//===----------------------------------------------------------------------===//
263class OMPIImplTraits : public MPIImplTraits {
264 LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
265 ConversionPatternRewriter &rewriter,
266 StringRef name,
267 LLVM::LLVMStructType type) {
268
269 return getOrDefineGlobal<LLVM::GlobalOp>(
270 getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
271 LLVM::Linkage::External, name,
272 /*value=*/Attribute(), /*alignment=*/0, 0);
273 }
274
275public:
276 using MPIImplTraits::MPIImplTraits;
277
278 ~OMPIImplTraits() override = default;
279
280 Value getCommWorld(const Location loc,
281 ConversionPatternRewriter &rewriter) override {
282 auto *context = rewriter.getContext();
283 // get external opaque struct pointer type
284 auto commStructT =
285 LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
286 StringRef name = "ompi_mpi_comm_world";
287
288 // make sure global op definition exists
289 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
290
291 // get address of symbol
292 auto comm = LLVM::AddressOfOp::create(rewriter, loc,
293 LLVM::LLVMPointerType::get(context),
294 SymbolRefAttr::get(context, name));
295 return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
296 }
297
298 Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
299 Value comm) override {
300 return LLVM::IntToPtrOp::create(
301 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
302 }
303
304 intptr_t getStatusIgnore() override { return 0; }
305
306 void *getInPlace() override { return reinterpret_cast<void *>(1); }
307
308 Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
309 Type type) override {
310 StringRef mtype;
311 if (type.isF32())
312 mtype = "ompi_mpi_float";
313 else if (type.isF64())
314 mtype = "ompi_mpi_double";
315 else if (type.isInteger(64) && !type.isUnsignedInteger())
316 mtype = "ompi_mpi_int64_t";
317 else if (type.isInteger(64))
318 mtype = "ompi_mpi_uint64_t";
319 else if (type.isInteger(32) && !type.isUnsignedInteger())
320 mtype = "ompi_mpi_int32_t";
321 else if (type.isInteger(32))
322 mtype = "ompi_mpi_uint32_t";
323 else if (type.isInteger(16) && !type.isUnsignedInteger())
324 mtype = "ompi_mpi_int16_t";
325 else if (type.isInteger(16))
326 mtype = "ompi_mpi_uint16_t";
327 else if (type.isInteger(8) && !type.isUnsignedInteger())
328 mtype = "ompi_mpi_int8_t";
329 else if (type.isInteger(8))
330 mtype = "ompi_mpi_uint8_t";
331 else
332 assert(false && "unsupported type");
333
334 auto *context = rewriter.getContext();
335 // get external opaque struct pointer type
336 auto typeStructT =
337 LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
338 // make sure global op definition exists
339 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
340 // get address of symbol
341 return LLVM::AddressOfOp::create(rewriter, loc,
342 LLVM::LLVMPointerType::get(context),
343 SymbolRefAttr::get(context, mtype));
344 }
345
346 Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
347 mpi::MPI_ReductionOpEnum opAttr) override {
348 StringRef op;
349 switch (opAttr) {
350 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
351 op = "ompi_mpi_no_op";
352 break;
353 case mpi::MPI_ReductionOpEnum::MPI_MAX:
354 op = "ompi_mpi_max";
355 break;
356 case mpi::MPI_ReductionOpEnum::MPI_MIN:
357 op = "ompi_mpi_min";
358 break;
359 case mpi::MPI_ReductionOpEnum::MPI_SUM:
360 op = "ompi_mpi_sum";
361 break;
362 case mpi::MPI_ReductionOpEnum::MPI_PROD:
363 op = "ompi_mpi_prod";
364 break;
365 case mpi::MPI_ReductionOpEnum::MPI_LAND:
366 op = "ompi_mpi_land";
367 break;
368 case mpi::MPI_ReductionOpEnum::MPI_BAND:
369 op = "ompi_mpi_band";
370 break;
371 case mpi::MPI_ReductionOpEnum::MPI_LOR:
372 op = "ompi_mpi_lor";
373 break;
374 case mpi::MPI_ReductionOpEnum::MPI_BOR:
375 op = "ompi_mpi_bor";
376 break;
377 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
378 op = "ompi_mpi_lxor";
379 break;
380 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
381 op = "ompi_mpi_bxor";
382 break;
383 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
384 op = "ompi_mpi_minloc";
385 break;
386 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
387 op = "ompi_mpi_maxloc";
388 break;
389 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
390 op = "ompi_mpi_replace";
391 break;
392 }
393 auto *context = rewriter.getContext();
394 // get external opaque struct pointer type
395 auto opStructT =
396 LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
397 // make sure global op definition exists
398 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
399 // get address of symbol
400 return LLVM::AddressOfOp::create(rewriter, loc,
401 LLVM::LLVMPointerType::get(context),
402 SymbolRefAttr::get(context, op));
403 }
404};
405
406std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
407 auto attr = dlti::query(moduleOp, {"MPI:Implementation"}, false);
408 if (failed(attr))
409 return std::make_unique<MPICHImplTraits>(moduleOp);
410 auto strAttr = dyn_cast<StringAttr>(attr.value());
411 if (strAttr && strAttr.getValue() == "OpenMPI")
412 return std::make_unique<OMPIImplTraits>(moduleOp);
413 if (!strAttr || strAttr.getValue() != "MPICH")
414 moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
415 << (strAttr ? strAttr.getValue() : "<NULL>")
416 << "), defaulting to MPICH";
417 return std::make_unique<MPICHImplTraits>(moduleOp);
418}
419
420//===----------------------------------------------------------------------===//
421// InitOpLowering
422//===----------------------------------------------------------------------===//
423
424struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
426
427 LogicalResult
428 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
429 ConversionPatternRewriter &rewriter) const override {
430 Location loc = op.getLoc();
431
432 // ptrType `!llvm.ptr`
433 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
434
435 // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
436 auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
437 Value llvmnull = nullPtrOp.getRes();
438
439 // grab a reference to the global module op:
440 auto moduleOp = op->getParentOfType<ModuleOp>();
441
442 // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
443 auto initFuncType =
444 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
445 // get or create function declaration:
446 LLVM::LLVMFuncOp initDecl =
447 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
448
449 // replace init with function call
450 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
451 ValueRange{llvmnull, llvmnull});
452
453 return success();
454 }
455};
456
457//===----------------------------------------------------------------------===//
458// FinalizeOpLowering
459//===----------------------------------------------------------------------===//
460
461struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
463
464 LogicalResult
465 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
466 ConversionPatternRewriter &rewriter) const override {
467 // get loc
468 Location loc = op.getLoc();
469
470 // grab a reference to the global module op:
471 auto moduleOp = op->getParentOfType<ModuleOp>();
472
473 // LLVM Function type representing `i32 MPI_Finalize()`
474 auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
475 // get or create function declaration:
476 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
477 moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
478
479 // replace init with function call
480 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
481
482 return success();
483 }
484};
485
486//===----------------------------------------------------------------------===//
487// CommWorldOpLowering
488//===----------------------------------------------------------------------===//
489
490struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
492
493 LogicalResult
494 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
495 ConversionPatternRewriter &rewriter) const override {
496 // grab a reference to the global module op:
497 auto moduleOp = op->getParentOfType<ModuleOp>();
498 auto mpiTraits = MPIImplTraits::get(moduleOp);
499 // get MPI_COMM_WORLD
500 rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
501
502 return success();
503 }
504};
505
506//===----------------------------------------------------------------------===//
507// CommSplitOpLowering
508//===----------------------------------------------------------------------===//
509
510struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
512
513 LogicalResult
514 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
515 ConversionPatternRewriter &rewriter) const override {
516 // grab a reference to the global module op:
517 auto moduleOp = op->getParentOfType<ModuleOp>();
518 auto mpiTraits = MPIImplTraits::get(moduleOp);
519 Type i32 = rewriter.getI32Type();
520 Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
521 Location loc = op.getLoc();
522
523 // get communicator
524 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
525 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
526 auto outPtr =
527 LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one);
528
529 // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
530 auto funcType =
531 LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
532 // get or create function declaration:
533 LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
534 "MPI_Comm_split", funcType);
535
536 auto callOp =
537 LLVM::CallOp::create(rewriter, loc, funcDecl,
538 ValueRange{comm, adaptor.getColor(),
539 adaptor.getKey(), outPtr.getRes()});
540
541 // load the communicator into a register
542 Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
543 res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
544
545 // if retval is checked, replace uses of retval with the results from the
546 // call op
547 SmallVector<Value> replacements;
548 if (op.getRetval())
549 replacements.push_back(callOp.getResult());
550
551 // replace op
552 replacements.push_back(res);
553 rewriter.replaceOp(op, replacements);
554
555 return success();
556 }
557};
558
559//===----------------------------------------------------------------------===//
560// CommRankOpLowering
561//===----------------------------------------------------------------------===//
562
563struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
565
566 LogicalResult
567 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
568 ConversionPatternRewriter &rewriter) const override {
569 // get some helper vars
570 Location loc = op.getLoc();
571 MLIRContext *context = rewriter.getContext();
572 Type i32 = rewriter.getI32Type();
573
574 // ptrType `!llvm.ptr`
575 Type ptrType = LLVM::LLVMPointerType::get(context);
576
577 // grab a reference to the global module op:
578 auto moduleOp = op->getParentOfType<ModuleOp>();
579
580 auto mpiTraits = MPIImplTraits::get(moduleOp);
581 // get communicator
582 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
583
584 // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
585 auto rankFuncType =
586 LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
587 // get or create function declaration:
588 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
589 moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
590
591 // replace with function call
592 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
593 auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
594 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
595 ValueRange{comm, rankptr.getRes()});
596
597 // load the rank into a register
598 auto loadedRank =
599 LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
600
601 // if retval is checked, replace uses of retval with the results from the
602 // call op
603 SmallVector<Value> replacements;
604 if (op.getRetval())
605 replacements.push_back(callOp.getResult());
606
607 // replace all uses, then erase op
608 replacements.push_back(loadedRank.getRes());
609 rewriter.replaceOp(op, replacements);
610
611 return success();
612 }
613};
614
615//===----------------------------------------------------------------------===//
616// CommSizeOpLowering
617//===----------------------------------------------------------------------===//
618
619static Value createOrFoldCommSize(ConversionPatternRewriter &rewriter,
620 Location loc, Value commOrg,
621 Value commAdapt) {
622 auto i32 = rewriter.getI32Type();
623 auto nRanksOp = mpi::CommSizeOp::create(rewriter, loc, i32, commOrg);
624 if (succeeded(FoldToDLTIConst(nRanksOp, "MPI:comm_world_size", rewriter)))
625 return nRanksOp.getSize();
626 rewriter.eraseOp(nRanksOp);
627 return mpi::CommSizeOp::create(rewriter, loc, i32, commAdapt).getSize();
628}
629
630struct CommSizeOpLowering : public ConvertOpToLLVMPattern<mpi::CommSizeOp> {
632
633 LogicalResult
634 matchAndRewrite(mpi::CommSizeOp op, OpAdaptor adaptor,
635 ConversionPatternRewriter &rewriter) const override {
636 // get some helper vars
637 Location loc = op.getLoc();
638 MLIRContext *context = rewriter.getContext();
639 Type i32 = rewriter.getI32Type();
640
641 // ptrType `!llvm.ptr`
642 Type ptrType = LLVM::LLVMPointerType::get(context);
643
644 // grab a reference to the global module op:
645 auto moduleOp = op->getParentOfType<ModuleOp>();
646
647 auto mpiTraits = MPIImplTraits::get(moduleOp);
648 // get communicator
649 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
650
651 // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
652 auto SizeFuncType =
653 LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
654 // get or create function declaration:
655 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
656 moduleOp, loc, rewriter, "MPI_Comm_size", SizeFuncType);
657
658 // replace with function call
659 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
660 auto sizeptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
661 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
662 ValueRange{comm, sizeptr.getRes()});
663
664 // load the Size into a register
665 auto loadedSize =
666 LLVM::LoadOp::create(rewriter, loc, i32, sizeptr.getResult());
667
668 // if retval is checked, replace uses of retval with the results from the
669 // call op
670 SmallVector<Value> replacements;
671 if (op.getRetval())
672 replacements.push_back(callOp.getResult());
673
674 // replace all uses, then erase op
675 replacements.push_back(loadedSize.getRes());
676 rewriter.replaceOp(op, replacements);
677
678 return success();
679 }
680};
681
682//===----------------------------------------------------------------------===//
683// SendOpLowering
684//===----------------------------------------------------------------------===//
685
686struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
688
689 LogicalResult
690 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
691 ConversionPatternRewriter &rewriter) const override {
692 // get some helper vars
693 Location loc = op.getLoc();
694 MLIRContext *context = rewriter.getContext();
695 Type i32 = rewriter.getI32Type();
696 Type elemType = op.getRef().getType().getElementType();
697 int64_t rank = op.getRef().getType().getRank();
698
699 // ptrType `!llvm.ptr`
700 Type ptrType = LLVM::LLVMPointerType::get(context);
701
702 // grab a reference to the global module op:
703 auto moduleOp = op->getParentOfType<ModuleOp>();
704
705 // get MPI_COMM_WORLD, dataType and pointer
706 auto [dataPtr, size] =
707 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
708 auto mpiTraits = MPIImplTraits::get(moduleOp);
709 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
710 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
711
712 // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
713 // tag, comm)`
714 auto funcType = LLVM::LLVMFunctionType::get(
715 i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
716 // get or create function declaration:
717 LLVM::LLVMFuncOp funcDecl =
718 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
719
720 // replace op with function call
721 auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
722 ValueRange{dataPtr, size, dataType,
723 adaptor.getDest(),
724 adaptor.getTag(), comm});
725 if (op.getRetval())
726 rewriter.replaceOp(op, funcCall.getResult());
727 else
728 rewriter.eraseOp(op);
729
730 return success();
731 }
732};
733
734//===----------------------------------------------------------------------===//
735// RecvOpLowering
736//===----------------------------------------------------------------------===//
737
738struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
740
741 LogicalResult
742 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
743 ConversionPatternRewriter &rewriter) const override {
744 // get some helper vars
745 Location loc = op.getLoc();
746 MLIRContext *context = rewriter.getContext();
747 Type i32 = rewriter.getI32Type();
748 Type i64 = rewriter.getI64Type();
749 Type elemType = op.getRef().getType().getElementType();
750 int64_t rank = op.getRef().getType().getRank();
751
752 // ptrType `!llvm.ptr`
753 Type ptrType = LLVM::LLVMPointerType::get(context);
754
755 // grab a reference to the global module op:
756 auto moduleOp = op->getParentOfType<ModuleOp>();
757
758 // get MPI_COMM_WORLD, dataType, status_ignore and pointer
759 auto [dataPtr, size] =
760 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), rank, elemType);
761 auto mpiTraits = MPIImplTraits::get(moduleOp);
762 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
763 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
764 Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
765 mpiTraits->getStatusIgnore());
766 statusIgnore =
767 LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
768
769 // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
770 // tag, comm)`
771 auto funcType =
772 LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
773 i32, comm.getType(), ptrType});
774 // get or create function declaration:
775 LLVM::LLVMFuncOp funcDecl =
776 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
777
778 // replace op with function call
779 auto funcCall = LLVM::CallOp::create(
780 rewriter, loc, funcDecl,
781 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
782 adaptor.getTag(), comm, statusIgnore});
783 if (op.getRetval())
784 rewriter.replaceOp(op, funcCall.getResult());
785 else
786 rewriter.eraseOp(op);
787
788 return success();
789 }
790};
791
792//===----------------------------------------------------------------------===//
793// AllGatherOpLowering
794//===----------------------------------------------------------------------===//
795
796struct AllGatherOpLowering : public ConvertOpToLLVMPattern<mpi::AllGatherOp> {
798
799 LogicalResult
800 matchAndRewrite(mpi::AllGatherOp op, OpAdaptor adaptor,
801 ConversionPatternRewriter &rewriter) const override {
802 Location loc = op.getLoc();
803 MLIRContext *context = rewriter.getContext();
804 Type sElemType = op.getSendbuf().getType().getElementType();
805 Type rElemType = op.getRecvbuf().getType().getElementType();
806 int64_t sRank = op.getSendbuf().getType().getRank();
807 int64_t rRank = op.getRecvbuf().getType().getRank();
808 auto [sendPtr, sendSize] =
809 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, sElemType);
810 auto [recvPtr, recvSize] =
811 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, rElemType);
812
813 auto moduleOp = op->getParentOfType<ModuleOp>();
814 auto mpiTraits = MPIImplTraits::get(moduleOp);
815 Value sDataType = mpiTraits->getDataType(loc, rewriter, sElemType);
816 Value rDataType = mpiTraits->getDataType(loc, rewriter, rElemType);
817 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
818
819 Type ptrType = LLVM::LLVMPointerType::get(context);
820 Type i32 = rewriter.getI32Type();
821 // int MPI_Allgather(
822 // const void* buffer_send, int count_send, MPI_Datatype datatype_send,
823 // void* buffer_recv, int count_recv, MPI_Datatype datatype_recv,
824 // MPI_Comm communicator);
825 auto funcType = LLVM::LLVMFunctionType::get(
826 i32, {ptrType, i32, sDataType.getType(), ptrType, i32,
827 rDataType.getType(), comm.getType()});
828 // get or create function declaration:
829 LLVM::LLVMFuncOp funcDecl =
830 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allgather", funcType);
831
832 // count_recv is the number of elements received from each rank, not total
833 Value nRanks =
834 createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
835 Value recvCountPerRank =
836 LLVM::UDivOp::create(rewriter, loc, i32, recvSize, nRanks);
837
838 // replace op with function call
839 auto funcCall =
840 LLVM::CallOp::create(rewriter, loc, funcDecl,
841 ValueRange{sendPtr, sendSize, sDataType, recvPtr,
842 recvCountPerRank, rDataType, comm});
843
844 if (op.getRetval())
845 rewriter.replaceOp(op, funcCall.getResult());
846 else
847 rewriter.eraseOp(op);
848
849 return success();
850 }
851};
852
853//===----------------------------------------------------------------------===//
854// AllReduceOpLowering
855//===----------------------------------------------------------------------===//
856
857struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
859
860 LogicalResult
861 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
862 ConversionPatternRewriter &rewriter) const override {
863 Location loc = op.getLoc();
864 MLIRContext *context = rewriter.getContext();
865 Type i32 = rewriter.getI32Type();
866 Type i64 = rewriter.getI64Type();
867 Type elemType = op.getSendbuf().getType().getElementType();
868 int64_t sRank = op.getSendbuf().getType().getRank();
869 int64_t rRank = op.getRecvbuf().getType().getRank();
870
871 // ptrType `!llvm.ptr`
872 Type ptrType = LLVM::LLVMPointerType::get(context);
873 auto moduleOp = op->getParentOfType<ModuleOp>();
874 auto mpiTraits = MPIImplTraits::get(moduleOp);
875 auto [sendPtr, sendSize] =
876 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
877 auto [recvPtr, recvSize] =
878 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
879
880 // If input and output are the same, request in-place operation.
881 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
882 sendPtr = LLVM::ConstantOp::create(
883 rewriter, loc, i64,
884 reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
885 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
886 }
887
888 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
889 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
890 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
891
892 // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
893 // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
894 auto funcType = LLVM::LLVMFunctionType::get(
895 i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
896 commWorld.getType()});
897 // get or create function declaration:
898 LLVM::LLVMFuncOp funcDecl =
899 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
900
901 // replace op with function call
902 auto funcCall = LLVM::CallOp::create(
903 rewriter, loc, funcDecl,
904 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
905
906 if (op.getRetval())
907 rewriter.replaceOp(op, funcCall.getResult());
908 else
909 rewriter.eraseOp(op);
910
911 return success();
912 }
913};
914
915//===----------------------------------------------------------------------===//
916// ReduceScatterBlockOpLowering
917//===----------------------------------------------------------------------===//
918
919struct ReduceScatterBlockOpLowering
920 : public ConvertOpToLLVMPattern<mpi::ReduceScatterBlockOp> {
922
923 LogicalResult
924 matchAndRewrite(mpi::ReduceScatterBlockOp op, OpAdaptor adaptor,
925 ConversionPatternRewriter &rewriter) const override {
926 Location loc = op.getLoc();
927 MLIRContext *context = rewriter.getContext();
928 Type i32 = rewriter.getI32Type();
929 Type i64 = rewriter.getI64Type();
930 Type elemType = op.getSendbuf().getType().getElementType();
931 int64_t sRank = op.getSendbuf().getType().getRank();
932 int64_t rRank = op.getRecvbuf().getType().getRank();
933
934 // ptrType `!llvm.ptr`
935 Type ptrType = LLVM::LLVMPointerType::get(context);
936 auto moduleOp = op->getParentOfType<ModuleOp>();
937 auto mpiTraits = MPIImplTraits::get(moduleOp);
938 auto [sendPtr, sendSize] =
939 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
940 auto [recvPtr, recvSize] =
941 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
942
943 // If input and output are the same, request in-place operation.
944 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
945 sendPtr = LLVM::ConstantOp::create(
946 rewriter, loc, i64,
947 reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
948 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
949 }
950
951 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
952 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
953 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
954
955 Value nRanks =
956 createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
957 Value totalExpected =
958 LLVM::MulOp::create(rewriter, loc, i32, recvSize, nRanks);
959 Value sizeIsValid = LLVM::ICmpOp::create(
960 rewriter, loc, LLVM::ICmpPredicate::eq, sendSize, totalExpected);
961 cf::AssertOp::create(rewriter, loc, sizeIsValid,
962 "Send buffer's size must be the receive buffer's size "
963 "times the number of ranks");
964
965 // 'int MPI_Reduce_scatter_block(const void *sendbuf, void *recvbuf,
966 // int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
967 auto funcType = LLVM::LLVMFunctionType::get(
968 i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
969 comm.getType()});
970 // get or create function declaration:
971 LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(
972 moduleOp, loc, rewriter, "MPI_Reduce_scatter_block", funcType);
973
974 // replace op with function call
975 auto funcCall = LLVM::CallOp::create(
976 rewriter, loc, funcDecl,
977 ValueRange{sendPtr, recvPtr, recvSize, dataType, mpiOp, comm});
978
979 if (op.getRetval())
980 rewriter.replaceOp(op, funcCall.getResult());
981 else
982 rewriter.eraseOp(op);
983
984 return success();
985 }
986};
987
988//===----------------------------------------------------------------------===//
989// ConvertToLLVMPatternInterface implementation
990//===----------------------------------------------------------------------===//
991
992/// Implement the interface to convert Func to LLVM.
993struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
995 /// Hook for derived dialect interface to provide conversion patterns
996 /// and mark dialect legal for the conversion target.
997 void populateConvertToLLVMConversionPatterns(
998 ConversionTarget &target, LLVMTypeConverter &typeConverter,
999 RewritePatternSet &patterns) const final {
1000 mpi::populateMPIToLLVMConversionPatterns(typeConverter, patterns);
1001 }
1002};
1003} // namespace
1004
1005//===----------------------------------------------------------------------===//
1006// Pattern Population
1007//===----------------------------------------------------------------------===//
1008
1010 RewritePatternSet &patterns) {
1011 // Using i64 as a portable, intermediate type for !mpi.comm.
1012 // It would be nicer to somehow get the right type directly, but TLDI is not
1013 // available here.
1014 converter.addConversion([](mpi::CommType type) {
1015 return IntegerType::get(type.getContext(), 64);
1016 });
1017 patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
1018 CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
1019 SendOpLowering, RecvOpLowering, AllGatherOpLowering,
1020 AllReduceOpLowering, ReduceScatterBlockOpLowering>(converter);
1021}
1022
1024 registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
1025 dialect->addInterfaces<FuncToLLVMDialectInterface>();
1026 });
1027}
return success()
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This provides public APIs that all operations should have.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
Definition DLTI.cpp:537
LogicalResult FoldToDLTIConst(OpT op, const char *key, mlir::PatternRewriter &b)
Definition Utils.h:19
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...