MLIR 22.0.0git
MPIToLLVM.cpp
Go to the documentation of this file.
1//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9//
10// Copyright (C) by Argonne National Laboratory
11// See COPYRIGHT in top-level directory
12// of MPICH source repository.
13//
14
24#include <memory>
25
26using namespace mlir;
27
28namespace {
29
30template <typename Op, typename... Args>
31static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
32 ConversionPatternRewriter &rewriter, StringRef name,
33 Args &&...args) {
34 Op ret;
35 if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
36 ConversionPatternRewriter::InsertionGuard guard(rewriter);
37 rewriter.setInsertionPointToStart(moduleOp.getBody());
38 ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
39 }
40 return ret;
41}
42
43static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
44 const Location loc,
45 ConversionPatternRewriter &rewriter,
46 StringRef name,
47 LLVM::LLVMFunctionType type) {
48 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
49 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
50}
51
52std::pair<Value, Value> getRawPtrAndSize(const Location loc,
53 ConversionPatternRewriter &rewriter,
54 Value memRef, Type elType) {
55 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
56 Value dataPtr =
57 LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
58 Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
59 rewriter.getI64Type(), memRef, 2);
60 Value resPtr =
61 LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
62 Value size;
63 if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
64 size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
65 ArrayRef<int64_t>{3, 0});
66 size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
67 } else {
68 size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
69 }
70 return {resPtr, size};
71}
72
73/// When lowering the mpi dialect to functions calls certain details
74/// differ between various MPI implementations. This class will provide
75/// these in a generic way, depending on the MPI implementation that got
76/// selected by the DLTI attribute on the module.
77class MPIImplTraits {
78 ModuleOp &moduleOp;
79
80public:
81 /// Instantiate a new MPIImplTraits object according to the DLTI attribute
82 /// on the given module. Default to MPICH if no attribute is present or
83 /// the value is unknown.
84 static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
85
86 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
87
88 virtual ~MPIImplTraits() = default;
89
90 ModuleOp &getModuleOp() { return moduleOp; }
91
92 /// Gets or creates MPI_COMM_WORLD as a Value.
93 /// Different MPI implementations have different communicator types.
94 /// Using i64 as a portable, intermediate type.
95 /// Appropriate cast needs to take place before calling MPI functions.
96 virtual Value getCommWorld(Location loc,
97 ConversionPatternRewriter &rewriter) = 0;
98
99 /// Type converter provides i64 type for communicator type.
100 /// Converts to native type, which might be ptr or int or whatever.
101 virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter,
102 Value comm) = 0;
103
104 /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
105 virtual intptr_t getStatusIgnore() = 0;
106
107 /// Get the MPI_IN_PLACE value (void *).
108 virtual void *getInPlace() = 0;
109
110 /// Gets or creates an MPI datatype as a value which corresponds to the given
111 /// type.
112 virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter,
113 Type type) = 0;
114
115 /// Gets or creates an MPI_Op value which corresponds to the given
116 /// enum value.
117 virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter,
118 mpi::MPI_ReductionOpEnum opAttr) = 0;
119};
120
121//===----------------------------------------------------------------------===//
122// Implementation details for MPICH ABI compatible MPI implementations
123//===----------------------------------------------------------------------===//
124
125class MPICHImplTraits : public MPIImplTraits {
126 static constexpr int MPI_FLOAT = 0x4c00040a;
127 static constexpr int MPI_DOUBLE = 0x4c00080b;
128 static constexpr int MPI_INT8_T = 0x4c000137;
129 static constexpr int MPI_INT16_T = 0x4c000238;
130 static constexpr int MPI_INT32_T = 0x4c000439;
131 static constexpr int MPI_INT64_T = 0x4c00083a;
132 static constexpr int MPI_UINT8_T = 0x4c00013b;
133 static constexpr int MPI_UINT16_T = 0x4c00023c;
134 static constexpr int MPI_UINT32_T = 0x4c00043d;
135 static constexpr int MPI_UINT64_T = 0x4c00083e;
136 static constexpr int MPI_MAX = 0x58000001;
137 static constexpr int MPI_MIN = 0x58000002;
138 static constexpr int MPI_SUM = 0x58000003;
139 static constexpr int MPI_PROD = 0x58000004;
140 static constexpr int MPI_LAND = 0x58000005;
141 static constexpr int MPI_BAND = 0x58000006;
142 static constexpr int MPI_LOR = 0x58000007;
143 static constexpr int MPI_BOR = 0x58000008;
144 static constexpr int MPI_LXOR = 0x58000009;
145 static constexpr int MPI_BXOR = 0x5800000a;
146 static constexpr int MPI_MINLOC = 0x5800000b;
147 static constexpr int MPI_MAXLOC = 0x5800000c;
148 static constexpr int MPI_REPLACE = 0x5800000d;
149 static constexpr int MPI_NO_OP = 0x5800000e;
150
151public:
152 using MPIImplTraits::MPIImplTraits;
153
154 ~MPICHImplTraits() override = default;
155
156 Value getCommWorld(const Location loc,
157 ConversionPatternRewriter &rewriter) override {
158 static constexpr int MPI_COMM_WORLD = 0x44000000;
159 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
160 MPI_COMM_WORLD);
161 }
162
163 Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
164 Value comm) override {
165 return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
166 }
167
168 intptr_t getStatusIgnore() override { return 1; }
169
170 void *getInPlace() override { return reinterpret_cast<void *>(-1); }
171
172 Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
173 Type type) override {
174 int32_t mtype = 0;
175 if (type.isF32())
176 mtype = MPI_FLOAT;
177 else if (type.isF64())
178 mtype = MPI_DOUBLE;
179 else if (type.isInteger(64) && !type.isUnsignedInteger())
180 mtype = MPI_INT64_T;
181 else if (type.isInteger(64))
182 mtype = MPI_UINT64_T;
183 else if (type.isInteger(32) && !type.isUnsignedInteger())
184 mtype = MPI_INT32_T;
185 else if (type.isInteger(32))
186 mtype = MPI_UINT32_T;
187 else if (type.isInteger(16) && !type.isUnsignedInteger())
188 mtype = MPI_INT16_T;
189 else if (type.isInteger(16))
190 mtype = MPI_UINT16_T;
191 else if (type.isInteger(8) && !type.isUnsignedInteger())
192 mtype = MPI_INT8_T;
193 else if (type.isInteger(8))
194 mtype = MPI_UINT8_T;
195 else
196 assert(false && "unsupported type");
197 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
198 mtype);
199 }
200
201 Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
202 mpi::MPI_ReductionOpEnum opAttr) override {
203 int32_t op = MPI_NO_OP;
204 switch (opAttr) {
205 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
206 op = MPI_NO_OP;
207 break;
208 case mpi::MPI_ReductionOpEnum::MPI_MAX:
209 op = MPI_MAX;
210 break;
211 case mpi::MPI_ReductionOpEnum::MPI_MIN:
212 op = MPI_MIN;
213 break;
214 case mpi::MPI_ReductionOpEnum::MPI_SUM:
215 op = MPI_SUM;
216 break;
217 case mpi::MPI_ReductionOpEnum::MPI_PROD:
218 op = MPI_PROD;
219 break;
220 case mpi::MPI_ReductionOpEnum::MPI_LAND:
221 op = MPI_LAND;
222 break;
223 case mpi::MPI_ReductionOpEnum::MPI_BAND:
224 op = MPI_BAND;
225 break;
226 case mpi::MPI_ReductionOpEnum::MPI_LOR:
227 op = MPI_LOR;
228 break;
229 case mpi::MPI_ReductionOpEnum::MPI_BOR:
230 op = MPI_BOR;
231 break;
232 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
233 op = MPI_LXOR;
234 break;
235 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
236 op = MPI_BXOR;
237 break;
238 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
239 op = MPI_MINLOC;
240 break;
241 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
242 op = MPI_MAXLOC;
243 break;
244 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
245 op = MPI_REPLACE;
246 break;
247 }
248 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
249 }
250};
251
252//===----------------------------------------------------------------------===//
253// Implementation details for OpenMPI
254//===----------------------------------------------------------------------===//
255class OMPIImplTraits : public MPIImplTraits {
256 LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
257 ConversionPatternRewriter &rewriter,
258 StringRef name,
259 LLVM::LLVMStructType type) {
260
261 return getOrDefineGlobal<LLVM::GlobalOp>(
262 getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
263 LLVM::Linkage::External, name,
264 /*value=*/Attribute(), /*alignment=*/0, 0);
265 }
266
267public:
268 using MPIImplTraits::MPIImplTraits;
269
270 ~OMPIImplTraits() override = default;
271
272 Value getCommWorld(const Location loc,
273 ConversionPatternRewriter &rewriter) override {
274 auto *context = rewriter.getContext();
275 // get external opaque struct pointer type
276 auto commStructT =
277 LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
278 StringRef name = "ompi_mpi_comm_world";
279
280 // make sure global op definition exists
281 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
282
283 // get address of symbol
284 auto comm = LLVM::AddressOfOp::create(rewriter, loc,
285 LLVM::LLVMPointerType::get(context),
286 SymbolRefAttr::get(context, name));
287 return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
288 }
289
290 Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
291 Value comm) override {
292 return LLVM::IntToPtrOp::create(
293 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
294 }
295
296 intptr_t getStatusIgnore() override { return 0; }
297
298 void *getInPlace() override { return reinterpret_cast<void *>(1); }
299
300 Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
301 Type type) override {
302 StringRef mtype;
303 if (type.isF32())
304 mtype = "ompi_mpi_float";
305 else if (type.isF64())
306 mtype = "ompi_mpi_double";
307 else if (type.isInteger(64) && !type.isUnsignedInteger())
308 mtype = "ompi_mpi_int64_t";
309 else if (type.isInteger(64))
310 mtype = "ompi_mpi_uint64_t";
311 else if (type.isInteger(32) && !type.isUnsignedInteger())
312 mtype = "ompi_mpi_int32_t";
313 else if (type.isInteger(32))
314 mtype = "ompi_mpi_uint32_t";
315 else if (type.isInteger(16) && !type.isUnsignedInteger())
316 mtype = "ompi_mpi_int16_t";
317 else if (type.isInteger(16))
318 mtype = "ompi_mpi_uint16_t";
319 else if (type.isInteger(8) && !type.isUnsignedInteger())
320 mtype = "ompi_mpi_int8_t";
321 else if (type.isInteger(8))
322 mtype = "ompi_mpi_uint8_t";
323 else
324 assert(false && "unsupported type");
325
326 auto *context = rewriter.getContext();
327 // get external opaque struct pointer type
328 auto typeStructT =
329 LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
330 // make sure global op definition exists
331 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
332 // get address of symbol
333 return LLVM::AddressOfOp::create(rewriter, loc,
334 LLVM::LLVMPointerType::get(context),
335 SymbolRefAttr::get(context, mtype));
336 }
337
338 Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
339 mpi::MPI_ReductionOpEnum opAttr) override {
340 StringRef op;
341 switch (opAttr) {
342 case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
343 op = "ompi_mpi_no_op";
344 break;
345 case mpi::MPI_ReductionOpEnum::MPI_MAX:
346 op = "ompi_mpi_max";
347 break;
348 case mpi::MPI_ReductionOpEnum::MPI_MIN:
349 op = "ompi_mpi_min";
350 break;
351 case mpi::MPI_ReductionOpEnum::MPI_SUM:
352 op = "ompi_mpi_sum";
353 break;
354 case mpi::MPI_ReductionOpEnum::MPI_PROD:
355 op = "ompi_mpi_prod";
356 break;
357 case mpi::MPI_ReductionOpEnum::MPI_LAND:
358 op = "ompi_mpi_land";
359 break;
360 case mpi::MPI_ReductionOpEnum::MPI_BAND:
361 op = "ompi_mpi_band";
362 break;
363 case mpi::MPI_ReductionOpEnum::MPI_LOR:
364 op = "ompi_mpi_lor";
365 break;
366 case mpi::MPI_ReductionOpEnum::MPI_BOR:
367 op = "ompi_mpi_bor";
368 break;
369 case mpi::MPI_ReductionOpEnum::MPI_LXOR:
370 op = "ompi_mpi_lxor";
371 break;
372 case mpi::MPI_ReductionOpEnum::MPI_BXOR:
373 op = "ompi_mpi_bxor";
374 break;
375 case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
376 op = "ompi_mpi_minloc";
377 break;
378 case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
379 op = "ompi_mpi_maxloc";
380 break;
381 case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
382 op = "ompi_mpi_replace";
383 break;
384 }
385 auto *context = rewriter.getContext();
386 // get external opaque struct pointer type
387 auto opStructT =
388 LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
389 // make sure global op definition exists
390 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
391 // get address of symbol
392 return LLVM::AddressOfOp::create(rewriter, loc,
393 LLVM::LLVMPointerType::get(context),
394 SymbolRefAttr::get(context, op));
395 }
396};
397
398std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
399 auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
400 if (failed(attr))
401 return std::make_unique<MPICHImplTraits>(moduleOp);
402 auto strAttr = dyn_cast<StringAttr>(attr.value());
403 if (strAttr && strAttr.getValue() == "OpenMPI")
404 return std::make_unique<OMPIImplTraits>(moduleOp);
405 if (!strAttr || strAttr.getValue() != "MPICH")
406 moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
407 << (strAttr ? strAttr.getValue() : "<NULL>")
408 << "), defaulting to MPICH";
409 return std::make_unique<MPICHImplTraits>(moduleOp);
410}
411
412//===----------------------------------------------------------------------===//
413// InitOpLowering
414//===----------------------------------------------------------------------===//
415
416struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
418
419 LogicalResult
420 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
421 ConversionPatternRewriter &rewriter) const override {
422 Location loc = op.getLoc();
423
424 // ptrType `!llvm.ptr`
425 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
426
427 // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
428 auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
429 Value llvmnull = nullPtrOp.getRes();
430
431 // grab a reference to the global module op:
432 auto moduleOp = op->getParentOfType<ModuleOp>();
433
434 // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
435 auto initFuncType =
436 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
437 // get or create function declaration:
438 LLVM::LLVMFuncOp initDecl =
439 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
440
441 // replace init with function call
442 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
443 ValueRange{llvmnull, llvmnull});
444
445 return success();
446 }
447};
448
449//===----------------------------------------------------------------------===//
450// FinalizeOpLowering
451//===----------------------------------------------------------------------===//
452
453struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
455
456 LogicalResult
457 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
458 ConversionPatternRewriter &rewriter) const override {
459 // get loc
460 Location loc = op.getLoc();
461
462 // grab a reference to the global module op:
463 auto moduleOp = op->getParentOfType<ModuleOp>();
464
465 // LLVM Function type representing `i32 MPI_Finalize()`
466 auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
467 // get or create function declaration:
468 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
469 moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
470
471 // replace init with function call
472 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
473
474 return success();
475 }
476};
477
478//===----------------------------------------------------------------------===//
479// CommWorldOpLowering
480//===----------------------------------------------------------------------===//
481
482struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
484
485 LogicalResult
486 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
487 ConversionPatternRewriter &rewriter) const override {
488 // grab a reference to the global module op:
489 auto moduleOp = op->getParentOfType<ModuleOp>();
490 auto mpiTraits = MPIImplTraits::get(moduleOp);
491 // get MPI_COMM_WORLD
492 rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
493
494 return success();
495 }
496};
497
498//===----------------------------------------------------------------------===//
499// CommSplitOpLowering
500//===----------------------------------------------------------------------===//
501
502struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
504
505 LogicalResult
506 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
507 ConversionPatternRewriter &rewriter) const override {
508 // grab a reference to the global module op:
509 auto moduleOp = op->getParentOfType<ModuleOp>();
510 auto mpiTraits = MPIImplTraits::get(moduleOp);
511 Type i32 = rewriter.getI32Type();
512 Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
513 Location loc = op.getLoc();
514
515 // get communicator
516 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
517 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
518 auto outPtr =
519 LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one);
520
521 // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
522 auto funcType =
523 LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
524 // get or create function declaration:
525 LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
526 "MPI_Comm_split", funcType);
527
528 auto callOp =
529 LLVM::CallOp::create(rewriter, loc, funcDecl,
530 ValueRange{comm, adaptor.getColor(),
531 adaptor.getKey(), outPtr.getRes()});
532
533 // load the communicator into a register
534 Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
535 res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
536
537 // if retval is checked, replace uses of retval with the results from the
538 // call op
539 SmallVector<Value> replacements;
540 if (op.getRetval())
541 replacements.push_back(callOp.getResult());
542
543 // replace op
544 replacements.push_back(res);
545 rewriter.replaceOp(op, replacements);
546
547 return success();
548 }
549};
550
551//===----------------------------------------------------------------------===//
552// CommRankOpLowering
553//===----------------------------------------------------------------------===//
554
555struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
557
558 LogicalResult
559 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
560 ConversionPatternRewriter &rewriter) const override {
561 // get some helper vars
562 Location loc = op.getLoc();
563 MLIRContext *context = rewriter.getContext();
564 Type i32 = rewriter.getI32Type();
565
566 // ptrType `!llvm.ptr`
567 Type ptrType = LLVM::LLVMPointerType::get(context);
568
569 // grab a reference to the global module op:
570 auto moduleOp = op->getParentOfType<ModuleOp>();
571
572 auto mpiTraits = MPIImplTraits::get(moduleOp);
573 // get communicator
574 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
575
576 // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
577 auto rankFuncType =
578 LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
579 // get or create function declaration:
580 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
581 moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
582
583 // replace with function call
584 auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
585 auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
586 auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
587 ValueRange{comm, rankptr.getRes()});
588
589 // load the rank into a register
590 auto loadedRank =
591 LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
592
593 // if retval is checked, replace uses of retval with the results from the
594 // call op
595 SmallVector<Value> replacements;
596 if (op.getRetval())
597 replacements.push_back(callOp.getResult());
598
599 // replace all uses, then erase op
600 replacements.push_back(loadedRank.getRes());
601 rewriter.replaceOp(op, replacements);
602
603 return success();
604 }
605};
606
607//===----------------------------------------------------------------------===//
608// SendOpLowering
609//===----------------------------------------------------------------------===//
610
611struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
613
614 LogicalResult
615 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
616 ConversionPatternRewriter &rewriter) const override {
617 // get some helper vars
618 Location loc = op.getLoc();
619 MLIRContext *context = rewriter.getContext();
620 Type i32 = rewriter.getI32Type();
621 Type elemType = op.getRef().getType().getElementType();
622
623 // ptrType `!llvm.ptr`
624 Type ptrType = LLVM::LLVMPointerType::get(context);
625
626 // grab a reference to the global module op:
627 auto moduleOp = op->getParentOfType<ModuleOp>();
628
629 // get MPI_COMM_WORLD, dataType and pointer
630 auto [dataPtr, size] =
631 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
632 auto mpiTraits = MPIImplTraits::get(moduleOp);
633 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
634 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
635
636 // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
637 // tag, comm)`
638 auto funcType = LLVM::LLVMFunctionType::get(
639 i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
640 // get or create function declaration:
641 LLVM::LLVMFuncOp funcDecl =
642 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
643
644 // replace op with function call
645 auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
646 ValueRange{dataPtr, size, dataType,
647 adaptor.getDest(),
648 adaptor.getTag(), comm});
649 if (op.getRetval())
650 rewriter.replaceOp(op, funcCall.getResult());
651 else
652 rewriter.eraseOp(op);
653
654 return success();
655 }
656};
657
658//===----------------------------------------------------------------------===//
659// RecvOpLowering
660//===----------------------------------------------------------------------===//
661
662struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
664
665 LogicalResult
666 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter) const override {
668 // get some helper vars
669 Location loc = op.getLoc();
670 MLIRContext *context = rewriter.getContext();
671 Type i32 = rewriter.getI32Type();
672 Type i64 = rewriter.getI64Type();
673 Type elemType = op.getRef().getType().getElementType();
674
675 // ptrType `!llvm.ptr`
676 Type ptrType = LLVM::LLVMPointerType::get(context);
677
678 // grab a reference to the global module op:
679 auto moduleOp = op->getParentOfType<ModuleOp>();
680
681 // get MPI_COMM_WORLD, dataType, status_ignore and pointer
682 auto [dataPtr, size] =
683 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
684 auto mpiTraits = MPIImplTraits::get(moduleOp);
685 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
686 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
687 Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
688 mpiTraits->getStatusIgnore());
689 statusIgnore =
690 LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
691
692 // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
693 // tag, comm)`
694 auto funcType =
695 LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
696 i32, comm.getType(), ptrType});
697 // get or create function declaration:
698 LLVM::LLVMFuncOp funcDecl =
699 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
700
701 // replace op with function call
702 auto funcCall = LLVM::CallOp::create(
703 rewriter, loc, funcDecl,
704 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
705 adaptor.getTag(), comm, statusIgnore});
706 if (op.getRetval())
707 rewriter.replaceOp(op, funcCall.getResult());
708 else
709 rewriter.eraseOp(op);
710
711 return success();
712 }
713};
714
715//===----------------------------------------------------------------------===//
716// AllReduceOpLowering
717//===----------------------------------------------------------------------===//
718
719struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
721
722 LogicalResult
723 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
724 ConversionPatternRewriter &rewriter) const override {
725 Location loc = op.getLoc();
726 MLIRContext *context = rewriter.getContext();
727 Type i32 = rewriter.getI32Type();
728 Type i64 = rewriter.getI64Type();
729 Type elemType = op.getSendbuf().getType().getElementType();
730
731 // ptrType `!llvm.ptr`
732 Type ptrType = LLVM::LLVMPointerType::get(context);
733 auto moduleOp = op->getParentOfType<ModuleOp>();
734 auto mpiTraits = MPIImplTraits::get(moduleOp);
735 auto [sendPtr, sendSize] =
736 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
737 auto [recvPtr, recvSize] =
738 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
739
740 // If input and output are the same, request in-place operation.
741 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
742 sendPtr = LLVM::ConstantOp::create(
743 rewriter, loc, i64,
744 reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
745 sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
746 }
747
748 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
749 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
750 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
751
752 // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
753 // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
754 auto funcType = LLVM::LLVMFunctionType::get(
755 i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
756 commWorld.getType()});
757 // get or create function declaration:
758 LLVM::LLVMFuncOp funcDecl =
759 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
760
761 // replace op with function call
762 auto funcCall = LLVM::CallOp::create(
763 rewriter, loc, funcDecl,
764 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
765
766 if (op.getRetval())
767 rewriter.replaceOp(op, funcCall.getResult());
768 else
769 rewriter.eraseOp(op);
770
771 return success();
772 }
773};
774
775//===----------------------------------------------------------------------===//
776// ConvertToLLVMPatternInterface implementation
777//===----------------------------------------------------------------------===//
778
779/// Implement the interface to convert Func to LLVM.
780struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
782 /// Hook for derived dialect interface to provide conversion patterns
783 /// and mark dialect legal for the conversion target.
784 void populateConvertToLLVMConversionPatterns(
785 ConversionTarget &target, LLVMTypeConverter &typeConverter,
786 RewritePatternSet &patterns) const final {
788 }
789};
790} // namespace
791
792//===----------------------------------------------------------------------===//
793// Pattern Population
794//===----------------------------------------------------------------------===//
795
798 // Using i64 as a portable, intermediate type for !mpi.comm.
799 // It would be nicer to somehow get the right type directly, but TLDI is not
800 // available here.
801 converter.addConversion([](mpi::CommType type) {
802 return IntegerType::get(type.getContext(), 64);
803 });
804 patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
805 FinalizeOpLowering, InitOpLowering, SendOpLowering,
806 RecvOpLowering, AllReduceOpLowering>(converter);
807}
808
810 registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
811 dialect->addInterfaces<FuncToLLVMDialectInterface>();
812 });
813}
return success()
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:213
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This provides public APIs that all operations should have.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:258
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
Definition DLTI.cpp:537
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Note that these functions don't take a SymbolTable because GPU module lowerings can have name collisi...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...