MLIR  22.0.0git
MPIToLLVM.cpp
Go to the documentation of this file.
1 //===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 //
10 // Copyright (C) by Argonne National Laboratory
11 // See COPYRIGHT in top-level directory
12 // of MPICH source repository.
13 //
14 
19 #include "mlir/Dialect/DLTI/DLTI.h"
24 #include <memory>
25 
26 using namespace mlir;
27 
28 namespace {
29 
30 template <typename Op, typename... Args>
31 static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
32  ConversionPatternRewriter &rewriter, StringRef name,
33  Args &&...args) {
34  Op ret;
35  if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
36  ConversionPatternRewriter::InsertionGuard guard(rewriter);
37  rewriter.setInsertionPointToStart(moduleOp.getBody());
38  ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
39  }
40  return ret;
41 }
42 
43 static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
44  const Location loc,
45  ConversionPatternRewriter &rewriter,
46  StringRef name,
47  LLVM::LLVMFunctionType type) {
48  return getOrDefineGlobal<LLVM::LLVMFuncOp>(
49  moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
50 }
51 
52 std::pair<Value, Value> getRawPtrAndSize(const Location loc,
53  ConversionPatternRewriter &rewriter,
54  Value memRef, Type elType) {
55  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
56  Value dataPtr =
57  LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
58  Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
59  rewriter.getI64Type(), memRef, 2);
60  Value resPtr =
61  LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
62  Value size;
63  if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
64  size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
65  ArrayRef<int64_t>{3, 0});
66  size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
67  } else {
68  size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
69  }
70  return {resPtr, size};
71 }
72 
73 /// When lowering the mpi dialect to functions calls certain details
74 /// differ between various MPI implementations. This class will provide
75 /// these in a generic way, depending on the MPI implementation that got
76 /// selected by the DLTI attribute on the module.
77 class MPIImplTraits {
78  ModuleOp &moduleOp;
79 
80 public:
81  /// Instantiate a new MPIImplTraits object according to the DLTI attribute
82  /// on the given module. Default to MPICH if no attribute is present or
83  /// the value is unknown.
84  static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
85 
86  explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
87 
88  virtual ~MPIImplTraits() = default;
89 
90  ModuleOp &getModuleOp() { return moduleOp; }
91 
92  /// Gets or creates MPI_COMM_WORLD as a Value.
93  /// Different MPI implementations have different communicator types.
94  /// Using i64 as a portable, intermediate type.
95  /// Appropriate cast needs to take place before calling MPI functions.
96  virtual Value getCommWorld(const Location loc,
97  ConversionPatternRewriter &rewriter) = 0;
98 
99  /// Type converter provides i64 type for communicator type.
100  /// Converts to native type, which might be ptr or int or whatever.
101  virtual Value castComm(const Location loc,
102  ConversionPatternRewriter &rewriter, Value comm) = 0;
103 
104  /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
105  virtual intptr_t getStatusIgnore() = 0;
106 
107  /// Get the MPI_IN_PLACE value (void *).
108  virtual void *getInPlace() = 0;
109 
110  /// Gets or creates an MPI datatype as a value which corresponds to the given
111  /// type.
112  virtual Value getDataType(const Location loc,
113  ConversionPatternRewriter &rewriter, Type type) = 0;
114 
115  /// Gets or creates an MPI_Op value which corresponds to the given
116  /// enum value.
117  virtual Value getMPIOp(const Location loc,
118  ConversionPatternRewriter &rewriter,
119  mpi::MPI_ReductionOpEnum opAttr) = 0;
120 };
121 
122 //===----------------------------------------------------------------------===//
123 // Implementation details for MPICH ABI compatible MPI implementations
124 //===----------------------------------------------------------------------===//
125 
126 class MPICHImplTraits : public MPIImplTraits {
127  static constexpr int MPI_FLOAT = 0x4c00040a;
128  static constexpr int MPI_DOUBLE = 0x4c00080b;
129  static constexpr int MPI_INT8_T = 0x4c000137;
130  static constexpr int MPI_INT16_T = 0x4c000238;
131  static constexpr int MPI_INT32_T = 0x4c000439;
132  static constexpr int MPI_INT64_T = 0x4c00083a;
133  static constexpr int MPI_UINT8_T = 0x4c00013b;
134  static constexpr int MPI_UINT16_T = 0x4c00023c;
135  static constexpr int MPI_UINT32_T = 0x4c00043d;
136  static constexpr int MPI_UINT64_T = 0x4c00083e;
137  static constexpr int MPI_MAX = 0x58000001;
138  static constexpr int MPI_MIN = 0x58000002;
139  static constexpr int MPI_SUM = 0x58000003;
140  static constexpr int MPI_PROD = 0x58000004;
141  static constexpr int MPI_LAND = 0x58000005;
142  static constexpr int MPI_BAND = 0x58000006;
143  static constexpr int MPI_LOR = 0x58000007;
144  static constexpr int MPI_BOR = 0x58000008;
145  static constexpr int MPI_LXOR = 0x58000009;
146  static constexpr int MPI_BXOR = 0x5800000a;
147  static constexpr int MPI_MINLOC = 0x5800000b;
148  static constexpr int MPI_MAXLOC = 0x5800000c;
149  static constexpr int MPI_REPLACE = 0x5800000d;
150  static constexpr int MPI_NO_OP = 0x5800000e;
151 
152 public:
153  using MPIImplTraits::MPIImplTraits;
154 
155  ~MPICHImplTraits() override = default;
156 
157  Value getCommWorld(const Location loc,
158  ConversionPatternRewriter &rewriter) override {
159  static constexpr int MPI_COMM_WORLD = 0x44000000;
160  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
161  MPI_COMM_WORLD);
162  }
163 
164  Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
165  Value comm) override {
166  return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
167  }
168 
169  intptr_t getStatusIgnore() override { return 1; }
170 
171  void *getInPlace() override { return reinterpret_cast<void *>(-1); }
172 
173  Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
174  Type type) override {
175  int32_t mtype = 0;
176  if (type.isF32())
177  mtype = MPI_FLOAT;
178  else if (type.isF64())
179  mtype = MPI_DOUBLE;
180  else if (type.isInteger(64) && !type.isUnsignedInteger())
181  mtype = MPI_INT64_T;
182  else if (type.isInteger(64))
183  mtype = MPI_UINT64_T;
184  else if (type.isInteger(32) && !type.isUnsignedInteger())
185  mtype = MPI_INT32_T;
186  else if (type.isInteger(32))
187  mtype = MPI_UINT32_T;
188  else if (type.isInteger(16) && !type.isUnsignedInteger())
189  mtype = MPI_INT16_T;
190  else if (type.isInteger(16))
191  mtype = MPI_UINT16_T;
192  else if (type.isInteger(8) && !type.isUnsignedInteger())
193  mtype = MPI_INT8_T;
194  else if (type.isInteger(8))
195  mtype = MPI_UINT8_T;
196  else
197  assert(false && "unsupported type");
198  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
199  mtype);
200  }
201 
202  Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
203  mpi::MPI_ReductionOpEnum opAttr) override {
204  int32_t op = MPI_NO_OP;
205  switch (opAttr) {
206  case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
207  op = MPI_NO_OP;
208  break;
209  case mpi::MPI_ReductionOpEnum::MPI_MAX:
210  op = MPI_MAX;
211  break;
212  case mpi::MPI_ReductionOpEnum::MPI_MIN:
213  op = MPI_MIN;
214  break;
215  case mpi::MPI_ReductionOpEnum::MPI_SUM:
216  op = MPI_SUM;
217  break;
218  case mpi::MPI_ReductionOpEnum::MPI_PROD:
219  op = MPI_PROD;
220  break;
221  case mpi::MPI_ReductionOpEnum::MPI_LAND:
222  op = MPI_LAND;
223  break;
224  case mpi::MPI_ReductionOpEnum::MPI_BAND:
225  op = MPI_BAND;
226  break;
227  case mpi::MPI_ReductionOpEnum::MPI_LOR:
228  op = MPI_LOR;
229  break;
230  case mpi::MPI_ReductionOpEnum::MPI_BOR:
231  op = MPI_BOR;
232  break;
233  case mpi::MPI_ReductionOpEnum::MPI_LXOR:
234  op = MPI_LXOR;
235  break;
236  case mpi::MPI_ReductionOpEnum::MPI_BXOR:
237  op = MPI_BXOR;
238  break;
239  case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
240  op = MPI_MINLOC;
241  break;
242  case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
243  op = MPI_MAXLOC;
244  break;
245  case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
246  op = MPI_REPLACE;
247  break;
248  }
249  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
250  }
251 };
252 
253 //===----------------------------------------------------------------------===//
254 // Implementation details for OpenMPI
255 //===----------------------------------------------------------------------===//
256 class OMPIImplTraits : public MPIImplTraits {
257  LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
258  ConversionPatternRewriter &rewriter,
259  StringRef name,
260  LLVM::LLVMStructType type) {
261 
262  return getOrDefineGlobal<LLVM::GlobalOp>(
263  getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
264  LLVM::Linkage::External, name,
265  /*value=*/Attribute(), /*alignment=*/0, 0);
266  }
267 
268 public:
269  using MPIImplTraits::MPIImplTraits;
270 
271  ~OMPIImplTraits() override = default;
272 
273  Value getCommWorld(const Location loc,
274  ConversionPatternRewriter &rewriter) override {
275  auto *context = rewriter.getContext();
276  // get external opaque struct pointer type
277  auto commStructT =
278  LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
279  StringRef name = "ompi_mpi_comm_world";
280 
281  // make sure global op definition exists
282  getOrDefineExternalStruct(loc, rewriter, name, commStructT);
283 
284  // get address of symbol
285  auto comm = LLVM::AddressOfOp::create(rewriter, loc,
287  SymbolRefAttr::get(context, name));
288  return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
289  }
290 
291  Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
292  Value comm) override {
293  return LLVM::IntToPtrOp::create(
294  rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
295  }
296 
297  intptr_t getStatusIgnore() override { return 0; }
298 
299  void *getInPlace() override { return reinterpret_cast<void *>(1); }
300 
301  Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
302  Type type) override {
303  StringRef mtype;
304  if (type.isF32())
305  mtype = "ompi_mpi_float";
306  else if (type.isF64())
307  mtype = "ompi_mpi_double";
308  else if (type.isInteger(64) && !type.isUnsignedInteger())
309  mtype = "ompi_mpi_int64_t";
310  else if (type.isInteger(64))
311  mtype = "ompi_mpi_uint64_t";
312  else if (type.isInteger(32) && !type.isUnsignedInteger())
313  mtype = "ompi_mpi_int32_t";
314  else if (type.isInteger(32))
315  mtype = "ompi_mpi_uint32_t";
316  else if (type.isInteger(16) && !type.isUnsignedInteger())
317  mtype = "ompi_mpi_int16_t";
318  else if (type.isInteger(16))
319  mtype = "ompi_mpi_uint16_t";
320  else if (type.isInteger(8) && !type.isUnsignedInteger())
321  mtype = "ompi_mpi_int8_t";
322  else if (type.isInteger(8))
323  mtype = "ompi_mpi_uint8_t";
324  else
325  assert(false && "unsupported type");
326 
327  auto *context = rewriter.getContext();
328  // get external opaque struct pointer type
329  auto typeStructT =
330  LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
331  // make sure global op definition exists
332  getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
333  // get address of symbol
334  return LLVM::AddressOfOp::create(rewriter, loc,
336  SymbolRefAttr::get(context, mtype));
337  }
338 
339  Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
340  mpi::MPI_ReductionOpEnum opAttr) override {
341  StringRef op;
342  switch (opAttr) {
343  case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
344  op = "ompi_mpi_no_op";
345  break;
346  case mpi::MPI_ReductionOpEnum::MPI_MAX:
347  op = "ompi_mpi_max";
348  break;
349  case mpi::MPI_ReductionOpEnum::MPI_MIN:
350  op = "ompi_mpi_min";
351  break;
352  case mpi::MPI_ReductionOpEnum::MPI_SUM:
353  op = "ompi_mpi_sum";
354  break;
355  case mpi::MPI_ReductionOpEnum::MPI_PROD:
356  op = "ompi_mpi_prod";
357  break;
358  case mpi::MPI_ReductionOpEnum::MPI_LAND:
359  op = "ompi_mpi_land";
360  break;
361  case mpi::MPI_ReductionOpEnum::MPI_BAND:
362  op = "ompi_mpi_band";
363  break;
364  case mpi::MPI_ReductionOpEnum::MPI_LOR:
365  op = "ompi_mpi_lor";
366  break;
367  case mpi::MPI_ReductionOpEnum::MPI_BOR:
368  op = "ompi_mpi_bor";
369  break;
370  case mpi::MPI_ReductionOpEnum::MPI_LXOR:
371  op = "ompi_mpi_lxor";
372  break;
373  case mpi::MPI_ReductionOpEnum::MPI_BXOR:
374  op = "ompi_mpi_bxor";
375  break;
376  case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
377  op = "ompi_mpi_minloc";
378  break;
379  case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
380  op = "ompi_mpi_maxloc";
381  break;
382  case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
383  op = "ompi_mpi_replace";
384  break;
385  }
386  auto *context = rewriter.getContext();
387  // get external opaque struct pointer type
388  auto opStructT =
389  LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
390  // make sure global op definition exists
391  getOrDefineExternalStruct(loc, rewriter, op, opStructT);
392  // get address of symbol
393  return LLVM::AddressOfOp::create(rewriter, loc,
395  SymbolRefAttr::get(context, op));
396  }
397 };
398 
399 std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
400  auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
401  if (failed(attr))
402  return std::make_unique<MPICHImplTraits>(moduleOp);
403  auto strAttr = dyn_cast<StringAttr>(attr.value());
404  if (strAttr && strAttr.getValue() == "OpenMPI")
405  return std::make_unique<OMPIImplTraits>(moduleOp);
406  if (!strAttr || strAttr.getValue() != "MPICH")
407  moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
408  << (strAttr ? strAttr.getValue() : "<NULL>")
409  << "), defaulting to MPICH";
410  return std::make_unique<MPICHImplTraits>(moduleOp);
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // InitOpLowering
415 //===----------------------------------------------------------------------===//
416 
417 struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
419 
420  LogicalResult
421  matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
422  ConversionPatternRewriter &rewriter) const override {
423  Location loc = op.getLoc();
424 
425  // ptrType `!llvm.ptr`
426  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
427 
428  // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
429  auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
430  Value llvmnull = nullPtrOp.getRes();
431 
432  // grab a reference to the global module op:
433  auto moduleOp = op->getParentOfType<ModuleOp>();
434 
435  // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
436  auto initFuncType =
437  LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
438  // get or create function declaration:
439  LLVM::LLVMFuncOp initDecl =
440  getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
441 
442  // replace init with function call
443  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
444  ValueRange{llvmnull, llvmnull});
445 
446  return success();
447  }
448 };
449 
450 //===----------------------------------------------------------------------===//
451 // FinalizeOpLowering
452 //===----------------------------------------------------------------------===//
453 
454 struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
456 
457  LogicalResult
458  matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
459  ConversionPatternRewriter &rewriter) const override {
460  // get loc
461  Location loc = op.getLoc();
462 
463  // grab a reference to the global module op:
464  auto moduleOp = op->getParentOfType<ModuleOp>();
465 
466  // LLVM Function type representing `i32 MPI_Finalize()`
467  auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
468  // get or create function declaration:
469  LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
470  moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
471 
472  // replace init with function call
473  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
474 
475  return success();
476  }
477 };
478 
479 //===----------------------------------------------------------------------===//
480 // CommWorldOpLowering
481 //===----------------------------------------------------------------------===//
482 
483 struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
485 
486  LogicalResult
487  matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
488  ConversionPatternRewriter &rewriter) const override {
489  // grab a reference to the global module op:
490  auto moduleOp = op->getParentOfType<ModuleOp>();
491  auto mpiTraits = MPIImplTraits::get(moduleOp);
492  // get MPI_COMM_WORLD
493  rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
494 
495  return success();
496  }
497 };
498 
499 //===----------------------------------------------------------------------===//
500 // CommSplitOpLowering
501 //===----------------------------------------------------------------------===//
502 
503 struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
505 
506  LogicalResult
507  matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
508  ConversionPatternRewriter &rewriter) const override {
509  // grab a reference to the global module op:
510  auto moduleOp = op->getParentOfType<ModuleOp>();
511  auto mpiTraits = MPIImplTraits::get(moduleOp);
512  Type i32 = rewriter.getI32Type();
513  Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
514  Location loc = op.getLoc();
515 
516  // get communicator
517  Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
518  auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
519  auto outPtr =
520  LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one);
521 
522  // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
523  auto funcType =
524  LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
525  // get or create function declaration:
526  LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
527  "MPI_Comm_split", funcType);
528 
529  auto callOp =
530  LLVM::CallOp::create(rewriter, loc, funcDecl,
531  ValueRange{comm, adaptor.getColor(),
532  adaptor.getKey(), outPtr.getRes()});
533 
534  // load the communicator into a register
535  Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
536  res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
537 
538  // if retval is checked, replace uses of retval with the results from the
539  // call op
540  SmallVector<Value> replacements;
541  if (op.getRetval())
542  replacements.push_back(callOp.getResult());
543 
544  // replace op
545  replacements.push_back(res);
546  rewriter.replaceOp(op, replacements);
547 
548  return success();
549  }
550 };
551 
552 //===----------------------------------------------------------------------===//
553 // CommRankOpLowering
554 //===----------------------------------------------------------------------===//
555 
556 struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
558 
559  LogicalResult
560  matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
561  ConversionPatternRewriter &rewriter) const override {
562  // get some helper vars
563  Location loc = op.getLoc();
564  MLIRContext *context = rewriter.getContext();
565  Type i32 = rewriter.getI32Type();
566 
567  // ptrType `!llvm.ptr`
568  Type ptrType = LLVM::LLVMPointerType::get(context);
569 
570  // grab a reference to the global module op:
571  auto moduleOp = op->getParentOfType<ModuleOp>();
572 
573  auto mpiTraits = MPIImplTraits::get(moduleOp);
574  // get communicator
575  Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
576 
577  // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
578  auto rankFuncType =
579  LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
580  // get or create function declaration:
581  LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
582  moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
583 
584  // replace with function call
585  auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
586  auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
587  auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
588  ValueRange{comm, rankptr.getRes()});
589 
590  // load the rank into a register
591  auto loadedRank =
592  LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
593 
594  // if retval is checked, replace uses of retval with the results from the
595  // call op
596  SmallVector<Value> replacements;
597  if (op.getRetval())
598  replacements.push_back(callOp.getResult());
599 
600  // replace all uses, then erase op
601  replacements.push_back(loadedRank.getRes());
602  rewriter.replaceOp(op, replacements);
603 
604  return success();
605  }
606 };
607 
608 //===----------------------------------------------------------------------===//
609 // SendOpLowering
610 //===----------------------------------------------------------------------===//
611 
612 struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
614 
615  LogicalResult
616  matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
617  ConversionPatternRewriter &rewriter) const override {
618  // get some helper vars
619  Location loc = op.getLoc();
620  MLIRContext *context = rewriter.getContext();
621  Type i32 = rewriter.getI32Type();
622  Type elemType = op.getRef().getType().getElementType();
623 
624  // ptrType `!llvm.ptr`
625  Type ptrType = LLVM::LLVMPointerType::get(context);
626 
627  // grab a reference to the global module op:
628  auto moduleOp = op->getParentOfType<ModuleOp>();
629 
630  // get MPI_COMM_WORLD, dataType and pointer
631  auto [dataPtr, size] =
632  getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
633  auto mpiTraits = MPIImplTraits::get(moduleOp);
634  Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
635  Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
636 
637  // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
638  // tag, comm)`
639  auto funcType = LLVM::LLVMFunctionType::get(
640  i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
641  // get or create function declaration:
642  LLVM::LLVMFuncOp funcDecl =
643  getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
644 
645  // replace op with function call
646  auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
647  ValueRange{dataPtr, size, dataType,
648  adaptor.getDest(),
649  adaptor.getTag(), comm});
650  if (op.getRetval())
651  rewriter.replaceOp(op, funcCall.getResult());
652  else
653  rewriter.eraseOp(op);
654 
655  return success();
656  }
657 };
658 
659 //===----------------------------------------------------------------------===//
660 // RecvOpLowering
661 //===----------------------------------------------------------------------===//
662 
663 struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
665 
666  LogicalResult
667  matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
668  ConversionPatternRewriter &rewriter) const override {
669  // get some helper vars
670  Location loc = op.getLoc();
671  MLIRContext *context = rewriter.getContext();
672  Type i32 = rewriter.getI32Type();
673  Type i64 = rewriter.getI64Type();
674  Type elemType = op.getRef().getType().getElementType();
675 
676  // ptrType `!llvm.ptr`
677  Type ptrType = LLVM::LLVMPointerType::get(context);
678 
679  // grab a reference to the global module op:
680  auto moduleOp = op->getParentOfType<ModuleOp>();
681 
682  // get MPI_COMM_WORLD, dataType, status_ignore and pointer
683  auto [dataPtr, size] =
684  getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
685  auto mpiTraits = MPIImplTraits::get(moduleOp);
686  Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
687  Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
688  Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
689  mpiTraits->getStatusIgnore());
690  statusIgnore =
691  LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
692 
693  // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
694  // tag, comm)`
695  auto funcType =
696  LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
697  i32, comm.getType(), ptrType});
698  // get or create function declaration:
699  LLVM::LLVMFuncOp funcDecl =
700  getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
701 
702  // replace op with function call
703  auto funcCall = LLVM::CallOp::create(
704  rewriter, loc, funcDecl,
705  ValueRange{dataPtr, size, dataType, adaptor.getSource(),
706  adaptor.getTag(), comm, statusIgnore});
707  if (op.getRetval())
708  rewriter.replaceOp(op, funcCall.getResult());
709  else
710  rewriter.eraseOp(op);
711 
712  return success();
713  }
714 };
715 
716 //===----------------------------------------------------------------------===//
717 // AllReduceOpLowering
718 //===----------------------------------------------------------------------===//
719 
720 struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
722 
723  LogicalResult
724  matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
725  ConversionPatternRewriter &rewriter) const override {
726  Location loc = op.getLoc();
727  MLIRContext *context = rewriter.getContext();
728  Type i32 = rewriter.getI32Type();
729  Type i64 = rewriter.getI64Type();
730  Type elemType = op.getSendbuf().getType().getElementType();
731 
732  // ptrType `!llvm.ptr`
733  Type ptrType = LLVM::LLVMPointerType::get(context);
734  auto moduleOp = op->getParentOfType<ModuleOp>();
735  auto mpiTraits = MPIImplTraits::get(moduleOp);
736  auto [sendPtr, sendSize] =
737  getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
738  auto [recvPtr, recvSize] =
739  getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
740 
741  // If input and output are the same, request in-place operation.
742  if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
743  sendPtr = LLVM::ConstantOp::create(
744  rewriter, loc, i64,
745  reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
746  sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
747  }
748 
749  Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
750  Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
751  Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
752 
753  // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
754  // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
755  auto funcType = LLVM::LLVMFunctionType::get(
756  i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
757  commWorld.getType()});
758  // get or create function declaration:
759  LLVM::LLVMFuncOp funcDecl =
760  getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
761 
762  // replace op with function call
763  auto funcCall = LLVM::CallOp::create(
764  rewriter, loc, funcDecl,
765  ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
766 
767  if (op.getRetval())
768  rewriter.replaceOp(op, funcCall.getResult());
769  else
770  rewriter.eraseOp(op);
771 
772  return success();
773  }
774 };
775 
776 //===----------------------------------------------------------------------===//
777 // ConvertToLLVMPatternInterface implementation
778 //===----------------------------------------------------------------------===//
779 
/// Implement the ConvertToLLVMPatternInterface for the MPI dialect.
/// NOTE: despite the "Func" in its name, this interface is attached to
/// mpi::MPIDialect (see registerConvertMPIToLLVMInterface).
struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
  }
};
791 } // namespace
792 
793 //===----------------------------------------------------------------------===//
794 // Pattern Population
795 //===----------------------------------------------------------------------===//
796 
  // Using i64 as a portable, intermediate type for !mpi.comm.
  // It would be nicer to somehow get the right type directly, but DLTI is not
  // available here. Each pattern converts i64 to the native communicator type
  // through MPIImplTraits::castComm where an MPI function needs it.
  converter.addConversion([](mpi::CommType type) {
    return IntegerType::get(type.getContext(), 64);
  });
  // Register every MPI op lowering; all of them use `converter`.
  patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
               FinalizeOpLowering, InitOpLowering, SendOpLowering,
               RecvOpLowering, AllReduceOpLowering>(converter);
}
809 
  // Register a dialect extension that attaches the ConvertToLLVM pattern
  // interface (FuncToLLVMDialectInterface) to the MPI dialect.
  registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
    dialect->addInterfaces<FuncToLLVMDialectInterface>();
  });
}
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:209
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:215
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:41
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition: ArithOps.cpp:258
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
Definition: DLTI.cpp:537
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Definition: MPIToLLVM.cpp:797
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
Definition: MPIToLLVM.cpp:810
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Find or create an external function declaration in the given module.