MLIR  21.0.0git
MPIToLLVM.cpp
Go to the documentation of this file.
1 //===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 //
10 // Copyright (C) by Argonne National Laboratory
11 // See COPYRIGHT in top-level directory
12 // of MPICH source repository.
13 //
14 
19 #include "mlir/Dialect/DLTI/DLTI.h"
24 #include <memory>
25 
26 using namespace mlir;
27 
28 namespace {
29 
30 template <typename Op, typename... Args>
31 static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
32  ConversionPatternRewriter &rewriter, StringRef name,
33  Args &&...args) {
34  Op ret;
35  if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
36  ConversionPatternRewriter::InsertionGuard guard(rewriter);
37  rewriter.setInsertionPointToStart(moduleOp.getBody());
38  ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
39  }
40  return ret;
41 }
42 
43 static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
44  const Location loc,
45  ConversionPatternRewriter &rewriter,
46  StringRef name,
47  LLVM::LLVMFunctionType type) {
48  return getOrDefineGlobal<LLVM::LLVMFuncOp>(
49  moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
50 }
51 
52 std::pair<Value, Value> getRawPtrAndSize(const Location loc,
53  ConversionPatternRewriter &rewriter,
54  Value memRef, Type elType) {
55  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
56  Value dataPtr =
57  rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
58  Value offset = rewriter.create<LLVM::ExtractValueOp>(
59  loc, rewriter.getI64Type(), memRef, 2);
60  Value resPtr =
61  rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
62  Value size;
63  if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
64  size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
65  ArrayRef<int64_t>{3, 0});
66  size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
67  } else {
68  size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32);
69  }
70  return {resPtr, size};
71 }
72 
73 /// When lowering the mpi dialect to functions calls certain details
74 /// differ between various MPI implementations. This class will provide
75 /// these in a generic way, depending on the MPI implementation that got
76 /// selected by the DLTI attribute on the module.
77 class MPIImplTraits {
78  ModuleOp &moduleOp;
79 
80 public:
81  /// Instantiate a new MPIImplTraits object according to the DLTI attribute
82  /// on the given module. Default to MPICH if no attribute is present or
83  /// the value is unknown.
84  static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
85 
86  explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
87 
88  virtual ~MPIImplTraits() = default;
89 
90  ModuleOp &getModuleOp() { return moduleOp; }
91 
92  /// Gets or creates MPI_COMM_WORLD as a Value.
93  /// Different MPI implementations have different communicator types.
94  /// Using i64 as a portable, intermediate type.
95  /// Appropriate cast needs to take place before calling MPI functions.
96  virtual Value getCommWorld(const Location loc,
97  ConversionPatternRewriter &rewriter) = 0;
98 
99  /// Type converter provides i64 type for communicator type.
100  /// Converts to native type, which might be ptr or int or whatever.
101  virtual Value castComm(const Location loc,
102  ConversionPatternRewriter &rewriter, Value comm) = 0;
103 
104  /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
105  virtual intptr_t getStatusIgnore() = 0;
106 
107  /// Get the MPI_IN_PLACE value (void *).
108  virtual void *getInPlace() = 0;
109 
110  /// Gets or creates an MPI datatype as a value which corresponds to the given
111  /// type.
112  virtual Value getDataType(const Location loc,
113  ConversionPatternRewriter &rewriter, Type type) = 0;
114 
115  /// Gets or creates an MPI_Op value which corresponds to the given
116  /// enum value.
117  virtual Value getMPIOp(const Location loc,
118  ConversionPatternRewriter &rewriter,
119  mpi::MPI_OpClassEnum opAttr) = 0;
120 };
121 
122 //===----------------------------------------------------------------------===//
123 // Implementation details for MPICH ABI compatible MPI implementations
124 //===----------------------------------------------------------------------===//
125 
126 class MPICHImplTraits : public MPIImplTraits {
127  static constexpr int MPI_FLOAT = 0x4c00040a;
128  static constexpr int MPI_DOUBLE = 0x4c00080b;
129  static constexpr int MPI_INT8_T = 0x4c000137;
130  static constexpr int MPI_INT16_T = 0x4c000238;
131  static constexpr int MPI_INT32_T = 0x4c000439;
132  static constexpr int MPI_INT64_T = 0x4c00083a;
133  static constexpr int MPI_UINT8_T = 0x4c00013b;
134  static constexpr int MPI_UINT16_T = 0x4c00023c;
135  static constexpr int MPI_UINT32_T = 0x4c00043d;
136  static constexpr int MPI_UINT64_T = 0x4c00083e;
137  static constexpr int MPI_MAX = 0x58000001;
138  static constexpr int MPI_MIN = 0x58000002;
139  static constexpr int MPI_SUM = 0x58000003;
140  static constexpr int MPI_PROD = 0x58000004;
141  static constexpr int MPI_LAND = 0x58000005;
142  static constexpr int MPI_BAND = 0x58000006;
143  static constexpr int MPI_LOR = 0x58000007;
144  static constexpr int MPI_BOR = 0x58000008;
145  static constexpr int MPI_LXOR = 0x58000009;
146  static constexpr int MPI_BXOR = 0x5800000a;
147  static constexpr int MPI_MINLOC = 0x5800000b;
148  static constexpr int MPI_MAXLOC = 0x5800000c;
149  static constexpr int MPI_REPLACE = 0x5800000d;
150  static constexpr int MPI_NO_OP = 0x5800000e;
151 
152 public:
153  using MPIImplTraits::MPIImplTraits;
154 
155  ~MPICHImplTraits() override = default;
156 
157  Value getCommWorld(const Location loc,
158  ConversionPatternRewriter &rewriter) override {
159  static constexpr int MPI_COMM_WORLD = 0x44000000;
160  return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
161  MPI_COMM_WORLD);
162  }
163 
164  Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
165  Value comm) override {
166  return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm);
167  }
168 
169  intptr_t getStatusIgnore() override { return 1; }
170 
171  void *getInPlace() override { return reinterpret_cast<void *>(-1); }
172 
173  Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
174  Type type) override {
175  int32_t mtype = 0;
176  if (type.isF32())
177  mtype = MPI_FLOAT;
178  else if (type.isF64())
179  mtype = MPI_DOUBLE;
180  else if (type.isInteger(64) && !type.isUnsignedInteger())
181  mtype = MPI_INT64_T;
182  else if (type.isInteger(64))
183  mtype = MPI_UINT64_T;
184  else if (type.isInteger(32) && !type.isUnsignedInteger())
185  mtype = MPI_INT32_T;
186  else if (type.isInteger(32))
187  mtype = MPI_UINT32_T;
188  else if (type.isInteger(16) && !type.isUnsignedInteger())
189  mtype = MPI_INT16_T;
190  else if (type.isInteger(16))
191  mtype = MPI_UINT16_T;
192  else if (type.isInteger(8) && !type.isUnsignedInteger())
193  mtype = MPI_INT8_T;
194  else if (type.isInteger(8))
195  mtype = MPI_UINT8_T;
196  else
197  assert(false && "unsupported type");
198  return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
199  }
200 
201  Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
202  mpi::MPI_OpClassEnum opAttr) override {
203  int32_t op = MPI_NO_OP;
204  switch (opAttr) {
205  case mpi::MPI_OpClassEnum::MPI_OP_NULL:
206  op = MPI_NO_OP;
207  break;
208  case mpi::MPI_OpClassEnum::MPI_MAX:
209  op = MPI_MAX;
210  break;
211  case mpi::MPI_OpClassEnum::MPI_MIN:
212  op = MPI_MIN;
213  break;
214  case mpi::MPI_OpClassEnum::MPI_SUM:
215  op = MPI_SUM;
216  break;
217  case mpi::MPI_OpClassEnum::MPI_PROD:
218  op = MPI_PROD;
219  break;
220  case mpi::MPI_OpClassEnum::MPI_LAND:
221  op = MPI_LAND;
222  break;
223  case mpi::MPI_OpClassEnum::MPI_BAND:
224  op = MPI_BAND;
225  break;
226  case mpi::MPI_OpClassEnum::MPI_LOR:
227  op = MPI_LOR;
228  break;
229  case mpi::MPI_OpClassEnum::MPI_BOR:
230  op = MPI_BOR;
231  break;
232  case mpi::MPI_OpClassEnum::MPI_LXOR:
233  op = MPI_LXOR;
234  break;
235  case mpi::MPI_OpClassEnum::MPI_BXOR:
236  op = MPI_BXOR;
237  break;
238  case mpi::MPI_OpClassEnum::MPI_MINLOC:
239  op = MPI_MINLOC;
240  break;
241  case mpi::MPI_OpClassEnum::MPI_MAXLOC:
242  op = MPI_MAXLOC;
243  break;
244  case mpi::MPI_OpClassEnum::MPI_REPLACE:
245  op = MPI_REPLACE;
246  break;
247  }
248  return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
249  }
250 };
251 
252 //===----------------------------------------------------------------------===//
253 // Implementation details for OpenMPI
254 //===----------------------------------------------------------------------===//
255 class OMPIImplTraits : public MPIImplTraits {
256  LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
257  ConversionPatternRewriter &rewriter,
258  StringRef name,
259  LLVM::LLVMStructType type) {
260 
261  return getOrDefineGlobal<LLVM::GlobalOp>(
262  getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
263  LLVM::Linkage::External, name,
264  /*value=*/Attribute(), /*alignment=*/0, 0);
265  }
266 
267 public:
268  using MPIImplTraits::MPIImplTraits;
269 
270  ~OMPIImplTraits() override = default;
271 
272  Value getCommWorld(const Location loc,
273  ConversionPatternRewriter &rewriter) override {
274  auto context = rewriter.getContext();
275  // get external opaque struct pointer type
276  auto commStructT =
277  LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
278  StringRef name = "ompi_mpi_comm_world";
279 
280  // make sure global op definition exists
281  getOrDefineExternalStruct(loc, rewriter, name, commStructT);
282 
283  // get address of symbol
284  auto comm = rewriter.create<LLVM::AddressOfOp>(
285  loc, LLVM::LLVMPointerType::get(context),
286  SymbolRefAttr::get(context, name));
287  return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm);
288  }
289 
290  Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
291  Value comm) override {
292  return rewriter.create<LLVM::IntToPtrOp>(
293  loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
294  }
295 
296  intptr_t getStatusIgnore() override { return 0; }
297 
298  void *getInPlace() override { return reinterpret_cast<void *>(1); }
299 
300  Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
301  Type type) override {
302  StringRef mtype;
303  if (type.isF32())
304  mtype = "ompi_mpi_float";
305  else if (type.isF64())
306  mtype = "ompi_mpi_double";
307  else if (type.isInteger(64) && !type.isUnsignedInteger())
308  mtype = "ompi_mpi_int64_t";
309  else if (type.isInteger(64))
310  mtype = "ompi_mpi_uint64_t";
311  else if (type.isInteger(32) && !type.isUnsignedInteger())
312  mtype = "ompi_mpi_int32_t";
313  else if (type.isInteger(32))
314  mtype = "ompi_mpi_uint32_t";
315  else if (type.isInteger(16) && !type.isUnsignedInteger())
316  mtype = "ompi_mpi_int16_t";
317  else if (type.isInteger(16))
318  mtype = "ompi_mpi_uint16_t";
319  else if (type.isInteger(8) && !type.isUnsignedInteger())
320  mtype = "ompi_mpi_int8_t";
321  else if (type.isInteger(8))
322  mtype = "ompi_mpi_uint8_t";
323  else
324  assert(false && "unsupported type");
325 
326  auto context = rewriter.getContext();
327  // get external opaque struct pointer type
328  auto typeStructT =
329  LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
330  // make sure global op definition exists
331  getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
332  // get address of symbol
333  return rewriter.create<LLVM::AddressOfOp>(
334  loc, LLVM::LLVMPointerType::get(context),
335  SymbolRefAttr::get(context, mtype));
336  }
337 
338  Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
339  mpi::MPI_OpClassEnum opAttr) override {
340  StringRef op;
341  switch (opAttr) {
342  case mpi::MPI_OpClassEnum::MPI_OP_NULL:
343  op = "ompi_mpi_no_op";
344  break;
345  case mpi::MPI_OpClassEnum::MPI_MAX:
346  op = "ompi_mpi_max";
347  break;
348  case mpi::MPI_OpClassEnum::MPI_MIN:
349  op = "ompi_mpi_min";
350  break;
351  case mpi::MPI_OpClassEnum::MPI_SUM:
352  op = "ompi_mpi_sum";
353  break;
354  case mpi::MPI_OpClassEnum::MPI_PROD:
355  op = "ompi_mpi_prod";
356  break;
357  case mpi::MPI_OpClassEnum::MPI_LAND:
358  op = "ompi_mpi_land";
359  break;
360  case mpi::MPI_OpClassEnum::MPI_BAND:
361  op = "ompi_mpi_band";
362  break;
363  case mpi::MPI_OpClassEnum::MPI_LOR:
364  op = "ompi_mpi_lor";
365  break;
366  case mpi::MPI_OpClassEnum::MPI_BOR:
367  op = "ompi_mpi_bor";
368  break;
369  case mpi::MPI_OpClassEnum::MPI_LXOR:
370  op = "ompi_mpi_lxor";
371  break;
372  case mpi::MPI_OpClassEnum::MPI_BXOR:
373  op = "ompi_mpi_bxor";
374  break;
375  case mpi::MPI_OpClassEnum::MPI_MINLOC:
376  op = "ompi_mpi_minloc";
377  break;
378  case mpi::MPI_OpClassEnum::MPI_MAXLOC:
379  op = "ompi_mpi_maxloc";
380  break;
381  case mpi::MPI_OpClassEnum::MPI_REPLACE:
382  op = "ompi_mpi_replace";
383  break;
384  }
385  auto context = rewriter.getContext();
386  // get external opaque struct pointer type
387  auto opStructT =
388  LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
389  // make sure global op definition exists
390  getOrDefineExternalStruct(loc, rewriter, op, opStructT);
391  // get address of symbol
392  return rewriter.create<LLVM::AddressOfOp>(
393  loc, LLVM::LLVMPointerType::get(context),
394  SymbolRefAttr::get(context, op));
395  }
396 };
397 
398 std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
399  auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
400  if (failed(attr))
401  return std::make_unique<MPICHImplTraits>(moduleOp);
402  auto strAttr = dyn_cast<StringAttr>(attr.value());
403  if (strAttr && strAttr.getValue() == "OpenMPI")
404  return std::make_unique<OMPIImplTraits>(moduleOp);
405  if (!strAttr || strAttr.getValue() != "MPICH")
406  moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
407  << strAttr.getValue() << "), defaulting to MPICH";
408  return std::make_unique<MPICHImplTraits>(moduleOp);
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // InitOpLowering
413 //===----------------------------------------------------------------------===//
414 
415 struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
417 
418  LogicalResult
419  matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
420  ConversionPatternRewriter &rewriter) const override {
421  Location loc = op.getLoc();
422 
423  // ptrType `!llvm.ptr`
424  Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
425 
426  // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
427  auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
428  Value llvmnull = nullPtrOp.getRes();
429 
430  // grab a reference to the global module op:
431  auto moduleOp = op->getParentOfType<ModuleOp>();
432 
433  // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
434  auto initFuncType =
435  LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
436  // get or create function declaration:
437  LLVM::LLVMFuncOp initDecl =
438  getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
439 
440  // replace init with function call
441  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
442  ValueRange{llvmnull, llvmnull});
443 
444  return success();
445  }
446 };
447 
448 //===----------------------------------------------------------------------===//
449 // FinalizeOpLowering
450 //===----------------------------------------------------------------------===//
451 
452 struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
454 
455  LogicalResult
456  matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
457  ConversionPatternRewriter &rewriter) const override {
458  // get loc
459  Location loc = op.getLoc();
460 
461  // grab a reference to the global module op:
462  auto moduleOp = op->getParentOfType<ModuleOp>();
463 
464  // LLVM Function type representing `i32 MPI_Finalize()`
465  auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
466  // get or create function declaration:
467  LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
468  moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
469 
470  // replace init with function call
471  rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
472 
473  return success();
474  }
475 };
476 
477 //===----------------------------------------------------------------------===//
478 // CommWorldOpLowering
479 //===----------------------------------------------------------------------===//
480 
481 struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
483 
484  LogicalResult
485  matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
486  ConversionPatternRewriter &rewriter) const override {
487  // grab a reference to the global module op:
488  auto moduleOp = op->getParentOfType<ModuleOp>();
489  auto mpiTraits = MPIImplTraits::get(moduleOp);
490  // get MPI_COMM_WORLD
491  rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
492 
493  return success();
494  }
495 };
496 
497 //===----------------------------------------------------------------------===//
498 // CommSplitOpLowering
499 //===----------------------------------------------------------------------===//
500 
501 struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
503 
504  LogicalResult
505  matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
506  ConversionPatternRewriter &rewriter) const override {
507  // grab a reference to the global module op:
508  auto moduleOp = op->getParentOfType<ModuleOp>();
509  auto mpiTraits = MPIImplTraits::get(moduleOp);
510  Type i32 = rewriter.getI32Type();
511  Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
512  Location loc = op.getLoc();
513 
514  // get communicator
515  Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
516  auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
517  auto outPtr =
518  rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one);
519 
520  // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
521  auto funcType =
522  LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
523  // get or create function declaration:
524  LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
525  "MPI_Comm_split", funcType);
526 
527  auto callOp = rewriter.create<LLVM::CallOp>(
528  loc, funcDecl,
529  ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
530  outPtr.getRes()});
531 
532  // load the communicator into a register
533  Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
534  res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res);
535 
536  // if retval is checked, replace uses of retval with the results from the
537  // call op
538  SmallVector<Value> replacements;
539  if (op.getRetval())
540  replacements.push_back(callOp.getResult());
541 
542  // replace op
543  replacements.push_back(res);
544  rewriter.replaceOp(op, replacements);
545 
546  return success();
547  }
548 };
549 
550 //===----------------------------------------------------------------------===//
551 // CommRankOpLowering
552 //===----------------------------------------------------------------------===//
553 
554 struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
556 
557  LogicalResult
558  matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
559  ConversionPatternRewriter &rewriter) const override {
560  // get some helper vars
561  Location loc = op.getLoc();
562  MLIRContext *context = rewriter.getContext();
563  Type i32 = rewriter.getI32Type();
564 
565  // ptrType `!llvm.ptr`
566  Type ptrType = LLVM::LLVMPointerType::get(context);
567 
568  // grab a reference to the global module op:
569  auto moduleOp = op->getParentOfType<ModuleOp>();
570 
571  auto mpiTraits = MPIImplTraits::get(moduleOp);
572  // get communicator
573  Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
574 
575  // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
576  auto rankFuncType =
577  LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
578  // get or create function declaration:
579  LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
580  moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
581 
582  // replace with function call
583  auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
584  auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
585  auto callOp = rewriter.create<LLVM::CallOp>(
586  loc, initDecl, ValueRange{comm, rankptr.getRes()});
587 
588  // load the rank into a register
589  auto loadedRank =
590  rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
591 
592  // if retval is checked, replace uses of retval with the results from the
593  // call op
594  SmallVector<Value> replacements;
595  if (op.getRetval())
596  replacements.push_back(callOp.getResult());
597 
598  // replace all uses, then erase op
599  replacements.push_back(loadedRank.getRes());
600  rewriter.replaceOp(op, replacements);
601 
602  return success();
603  }
604 };
605 
606 //===----------------------------------------------------------------------===//
607 // SendOpLowering
608 //===----------------------------------------------------------------------===//
609 
610 struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
612 
613  LogicalResult
614  matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
615  ConversionPatternRewriter &rewriter) const override {
616  // get some helper vars
617  Location loc = op.getLoc();
618  MLIRContext *context = rewriter.getContext();
619  Type i32 = rewriter.getI32Type();
620  Type elemType = op.getRef().getType().getElementType();
621 
622  // ptrType `!llvm.ptr`
623  Type ptrType = LLVM::LLVMPointerType::get(context);
624 
625  // grab a reference to the global module op:
626  auto moduleOp = op->getParentOfType<ModuleOp>();
627 
628  // get MPI_COMM_WORLD, dataType and pointer
629  auto [dataPtr, size] =
630  getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
631  auto mpiTraits = MPIImplTraits::get(moduleOp);
632  Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
633  Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
634 
635  // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
636  // tag, comm)`
637  auto funcType = LLVM::LLVMFunctionType::get(
638  i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
639  // get or create function declaration:
640  LLVM::LLVMFuncOp funcDecl =
641  getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
642 
643  // replace op with function call
644  auto funcCall = rewriter.create<LLVM::CallOp>(
645  loc, funcDecl,
646  ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
647  comm});
648  if (op.getRetval())
649  rewriter.replaceOp(op, funcCall.getResult());
650  else
651  rewriter.eraseOp(op);
652 
653  return success();
654  }
655 };
656 
657 //===----------------------------------------------------------------------===//
658 // RecvOpLowering
659 //===----------------------------------------------------------------------===//
660 
661 struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
663 
664  LogicalResult
665  matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
666  ConversionPatternRewriter &rewriter) const override {
667  // get some helper vars
668  Location loc = op.getLoc();
669  MLIRContext *context = rewriter.getContext();
670  Type i32 = rewriter.getI32Type();
671  Type i64 = rewriter.getI64Type();
672  Type elemType = op.getRef().getType().getElementType();
673 
674  // ptrType `!llvm.ptr`
675  Type ptrType = LLVM::LLVMPointerType::get(context);
676 
677  // grab a reference to the global module op:
678  auto moduleOp = op->getParentOfType<ModuleOp>();
679 
680  // get MPI_COMM_WORLD, dataType, status_ignore and pointer
681  auto [dataPtr, size] =
682  getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
683  auto mpiTraits = MPIImplTraits::get(moduleOp);
684  Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
685  Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
686  Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
687  loc, i64, mpiTraits->getStatusIgnore());
688  statusIgnore =
689  rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
690 
691  // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
692  // tag, comm)`
693  auto funcType =
694  LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
695  i32, comm.getType(), ptrType});
696  // get or create function declaration:
697  LLVM::LLVMFuncOp funcDecl =
698  getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
699 
700  // replace op with function call
701  auto funcCall = rewriter.create<LLVM::CallOp>(
702  loc, funcDecl,
703  ValueRange{dataPtr, size, dataType, adaptor.getSource(),
704  adaptor.getTag(), comm, statusIgnore});
705  if (op.getRetval())
706  rewriter.replaceOp(op, funcCall.getResult());
707  else
708  rewriter.eraseOp(op);
709 
710  return success();
711  }
712 };
713 
714 //===----------------------------------------------------------------------===//
715 // AllReduceOpLowering
716 //===----------------------------------------------------------------------===//
717 
718 struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
720 
721  LogicalResult
722  matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
723  ConversionPatternRewriter &rewriter) const override {
724  Location loc = op.getLoc();
725  MLIRContext *context = rewriter.getContext();
726  Type i32 = rewriter.getI32Type();
727  Type i64 = rewriter.getI64Type();
728  Type elemType = op.getSendbuf().getType().getElementType();
729 
730  // ptrType `!llvm.ptr`
731  Type ptrType = LLVM::LLVMPointerType::get(context);
732  auto moduleOp = op->getParentOfType<ModuleOp>();
733  auto mpiTraits = MPIImplTraits::get(moduleOp);
734  auto [sendPtr, sendSize] =
735  getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
736  auto [recvPtr, recvSize] =
737  getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
738 
739  // If input and output are the same, request in-place operation.
740  if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
741  sendPtr = rewriter.create<LLVM::ConstantOp>(
742  loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
743  sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
744  }
745 
746  Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
747  Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
748  Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
749 
750  // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
751  // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
752  auto funcType = LLVM::LLVMFunctionType::get(
753  i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
754  commWorld.getType()});
755  // get or create function declaration:
756  LLVM::LLVMFuncOp funcDecl =
757  getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
758 
759  // replace op with function call
760  auto funcCall = rewriter.create<LLVM::CallOp>(
761  loc, funcDecl,
762  ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
763 
764  if (op.getRetval())
765  rewriter.replaceOp(op, funcCall.getResult());
766  else
767  rewriter.eraseOp(op);
768 
769  return success();
770  }
771 };
772 
773 //===----------------------------------------------------------------------===//
774 // ConvertToLLVMPatternInterface implementation
775 //===----------------------------------------------------------------------===//
776 
777 /// Implement the interface to convert Func to LLVM.
778 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
780  /// Hook for derived dialect interface to provide conversion patterns
781  /// and mark dialect legal for the conversion target.
782  void populateConvertToLLVMConversionPatterns(
783  ConversionTarget &target, LLVMTypeConverter &typeConverter,
784  RewritePatternSet &patterns) const final {
786  }
787 };
788 } // namespace
789 
790 //===----------------------------------------------------------------------===//
791 // Pattern Population
792 //===----------------------------------------------------------------------===//
793 
796  // Using i64 as a portable, intermediate type for !mpi.comm.
797  // It would be nicer to somehow get the right type directly, but TLDI is not
798  // available here.
799  converter.addConversion([](mpi::CommType type) {
800  return IntegerType::get(type.getContext(), 64);
801  });
802  patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
803  FinalizeOpLowering, InitOpLowering, SendOpLowering,
804  RecvOpLowering, AllReduceOpLowering>(converter);
805 }
806 
808  registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
809  dialect->addInterfaces<FuncToLLVMDialectInterface>();
810  });
811 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:155
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:161
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:41
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
Definition: DLTI.cpp:527
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Definition: MPIToLLVM.cpp:794
void registerConvertMPIToLLVMInterface(DialectRegistry &registry)
Definition: MPIToLLVM.cpp:807
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc, OpBuilder &b, StringRef name, LLVM::LLVMFunctionType type)
Find or create an external function declaration in the given module.