MLIR  20.0.0git
MeshToMPI.cpp
Go to the documentation of this file.
1 //===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation of Mesh communication ops to MPI ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
23 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/IR/SymbolTable.h"
29 
30 #define DEBUG_TYPE "mesh-to-mpi"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32 
33 namespace mlir {
34 #define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
35 #include "mlir/Conversion/Passes.h.inc"
36 } // namespace mlir
37 
38 using namespace mlir;
39 using namespace mlir::mesh;
40 
41 namespace {
42 // Create operations converting a linear index to a multi-dimensional index
43 static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
44  Value linearIndex,
45  ValueRange dimensions) {
46  int n = dimensions.size();
47  SmallVector<Value> multiIndex(n);
48 
49  for (int i = n - 1; i >= 0; --i) {
50  multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
51  if (i > 0) {
52  linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
53  }
54  }
55 
56  return multiIndex;
57 }
58 
59 // Create operations converting a multi-dimensional index to a linear index
60 Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
61  ValueRange dimensions) {
62 
63  auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
64  auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
65 
66  for (int i = multiIndex.size() - 1; i >= 0; --i) {
67  auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
68  linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
69  stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
70  }
71 
72  return linearIndex;
73 }
74 
75 struct ConvertProcessMultiIndexOp
76  : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
78 
79  mlir::LogicalResult
80  matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
81  mlir::PatternRewriter &rewriter) const override {
82 
83  // Currently converts its linear index to a multi-dimensional index.
84 
85  SymbolTableCollection symbolTableCollection;
86  auto loc = op.getLoc();
87  auto meshOp = getMesh(op, symbolTableCollection);
88  // For now we only support static mesh shapes
89  if (ShapedType::isDynamicShape(meshOp.getShape())) {
90  return mlir::failure();
91  }
92 
93  SmallVector<Value> dims;
94  llvm::transform(
95  meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
96  return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
97  });
98  auto rank =
99  rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
100  auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
101 
102  // optionally extract subset of mesh axes
103  auto axes = op.getAxes();
104  if (!axes.empty()) {
105  SmallVector<Value> subIndex;
106  for (auto axis : axes) {
107  subIndex.push_back(mIdx[axis]);
108  }
109  mIdx = subIndex;
110  }
111 
112  rewriter.replaceOp(op, mIdx);
113  return mlir::success();
114  }
115 };
116 
117 struct ConvertProcessLinearIndexOp
118  : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
120 
121  mlir::LogicalResult
122  matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
123  mlir::PatternRewriter &rewriter) const override {
124 
125  // Finds a global named "static_mpi_rank" it will use that splat value.
126  // Otherwise it defaults to mpi.comm_rank.
127 
128  auto loc = op.getLoc();
129  auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
130  if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
131  op, rankOpName)) {
132  if (auto initTnsr = globalOp.getInitialValueAttr()) {
133  auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
134  rewriter.replaceOp(op,
135  rewriter.create<arith::ConstantIndexOp>(loc, val));
136  return mlir::success();
137  }
138  }
139  auto rank =
140  rewriter
141  .create<mpi::CommRankOp>(
142  op.getLoc(), TypeRange{mpi::RetvalType::get(op->getContext()),
143  rewriter.getI32Type()})
144  .getRank();
145  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
146  rank);
147  return mlir::success();
148  }
149 };
150 
151 struct ConvertNeighborsLinearIndicesOp
152  : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
154 
155  mlir::LogicalResult
156  matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
157  mlir::PatternRewriter &rewriter) const override {
158 
159  // Computes the neighbors indices along a split axis by simply
160  // adding/subtracting 1 to the current index in that dimension.
161  // Assigns -1 if neighbor is out of bounds.
162 
163  auto axes = op.getSplitAxes();
164  // For now only single axis sharding is supported
165  if (axes.size() != 1) {
166  return mlir::failure();
167  }
168 
169  auto loc = op.getLoc();
170  SymbolTableCollection symbolTableCollection;
171  auto meshOp = getMesh(op, symbolTableCollection);
172  auto mIdx = op.getDevice();
173  auto orgIdx = mIdx[axes[0]];
174  SmallVector<Value> dims;
175  llvm::transform(
176  meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
177  return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
178  });
179  auto dimSz = dims[axes[0]];
180  auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
181  auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
182  auto atBorder = rewriter.create<arith::CmpIOp>(
183  loc, arith::CmpIPredicate::sle, orgIdx,
184  rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
185  auto down = rewriter.create<scf::IfOp>(
186  loc, atBorder,
187  [&](OpBuilder &builder, Location loc) {
188  builder.create<scf::YieldOp>(loc, minus1);
189  },
190  [&](OpBuilder &builder, Location loc) {
191  SmallVector<Value> tmp = mIdx;
192  tmp[axes[0]] =
193  rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
194  .getResult();
195  builder.create<scf::YieldOp>(
196  loc, multiToLinearIndex(loc, rewriter, tmp, dims));
197  });
198  atBorder = rewriter.create<arith::CmpIOp>(
199  loc, arith::CmpIPredicate::sge, orgIdx,
200  rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
201  auto up = rewriter.create<scf::IfOp>(
202  loc, atBorder,
203  [&](OpBuilder &builder, Location loc) {
204  builder.create<scf::YieldOp>(loc, minus1);
205  },
206  [&](OpBuilder &builder, Location loc) {
207  SmallVector<Value> tmp = mIdx;
208  tmp[axes[0]] =
209  rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
210  .getResult();
211  builder.create<scf::YieldOp>(
212  loc, multiToLinearIndex(loc, rewriter, tmp, dims));
213  });
214  rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
215  return mlir::success();
216  }
217 };
218 
219 struct ConvertUpdateHaloOp
220  : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
222 
223  mlir::LogicalResult
224  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
225  mlir::PatternRewriter &rewriter) const override {
226 
227  // The input/output memref is assumed to be in C memory order.
228  // Halos are exchanged as 2 blocks per dimension (one for each side: down
229  // and up). For each haloed dimension `d`, the exchanged blocks are
230  // expressed as multi-dimensional subviews. The subviews include potential
231  // halos of higher dimensions `dh > d`, no halos for the lower dimensions
232  // `dl < d` and for dimension `d` the currently exchanged halo only.
233  // By iterating form higher to lower dimensions this also updates the halos
234  // in the 'corners'.
235  // memref.subview is used to read and write the halo data from and to the
236  // local data. Because subviews and halos can have mixed dynamic and static
237  // shapes, OpFoldResults are used whenever possible.
238 
239  SymbolTableCollection symbolTableCollection;
240  auto loc = op.getLoc();
241 
242  // convert a OpFoldResult into a Value
243  auto toValue = [&rewriter, &loc](OpFoldResult &v) {
244  return v.is<Value>()
245  ? v.get<Value>()
246  : rewriter.create<::mlir::arith::ConstantOp>(
247  loc,
248  rewriter.getIndexAttr(
249  cast<IntegerAttr>(v.get<Attribute>()).getInt()));
250  };
251 
252  auto dest = op.getDestination();
253  auto dstShape = cast<ShapedType>(dest.getType()).getShape();
254  Value array = dest;
255  if (isa<RankedTensorType>(array.getType())) {
256  // If the destination is a memref, we need to cast it to a tensor
257  auto tensorType = MemRefType::get(
258  dstShape, cast<ShapedType>(array.getType()).getElementType());
259  array = rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array)
260  .getResult();
261  }
262  auto rank = cast<ShapedType>(array.getType()).getRank();
263  auto opSplitAxes = op.getSplitAxes().getAxes();
264  auto mesh = op.getMesh();
265  auto meshOp = getMesh(op, symbolTableCollection);
266  auto haloSizes =
267  getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
268  // subviews need Index values
269  for (auto &sz : haloSizes) {
270  if (sz.is<Value>()) {
271  sz = rewriter
272  .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
273  sz.get<Value>())
274  .getResult();
275  }
276  }
277 
278  // most of the offset/size/stride data is the same for all dims
279  SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
280  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
281  SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
282  auto currHaloDim = -1; // halo sizes are provided for split dimensions only
283  // we need the actual shape to compute offsets and sizes
284  for (auto i = 0; i < rank; ++i) {
285  auto s = dstShape[i];
286  if (ShapedType::isDynamic(s)) {
287  shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
288  } else {
289  shape[i] = rewriter.getIndexAttr(s);
290  }
291 
292  if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
293  ++currHaloDim;
294  // the offsets for lower dim sstarts after their down halo
295  offsets[i] = haloSizes[currHaloDim * 2];
296 
297  // prepare shape and offsets of highest dim's halo exchange
298  auto _haloSz =
299  rewriter
300  .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
301  toValue(haloSizes[currHaloDim * 2 + 1]))
302  .getResult();
303  // the halo shape of lower dims exlude the halos
304  dimSizes[i] =
305  rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
306  .getResult();
307  } else {
308  dimSizes[i] = shape[i];
309  }
310  }
311 
312  auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
313  auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
314  auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
315  auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
316 
317  SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
318  rewriter.getIndexType());
319  auto myMultiIndex =
320  rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
321  .getResult();
322  // traverse all split axes from high to low dim
323  for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
324  auto splitAxes = opSplitAxes[dim];
325  if (splitAxes.empty()) {
326  continue;
327  }
328  assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
329  // Get the linearized ids of the neighbors (down and up) for the
330  // given split
331  auto tmp = rewriter
332  .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
333  splitAxes)
334  .getResults();
335  // MPI operates on i32...
336  Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
337  loc, rewriter.getI32Type(), tmp[0]),
338  rewriter.create<arith::IndexCastOp>(
339  loc, rewriter.getI32Type(), tmp[1])};
340 
341  auto lowerRecvOffset = rewriter.getIndexAttr(0);
342  auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
343  auto upperRecvOffset = rewriter.create<arith::SubIOp>(
344  loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
345  auto upperSendOffset = rewriter.create<arith::SubIOp>(
346  loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
347 
348  // Make sure we send/recv in a way that does not lead to a dead-lock.
349  // The current approach is by far not optimal, this should be at least
350  // be a red-black pattern or using MPI_sendrecv.
351  // Also, buffers should be re-used.
352  // Still using temporary contiguous buffers for MPI communication...
353  // Still yielding a "serialized" communication pattern...
354  auto genSendRecv = [&](bool upperHalo) {
355  auto orgOffset = offsets[dim];
356  dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
357  : haloSizes[currHaloDim * 2];
358  // Check if we need to send and/or receive
359  // Processes on the mesh borders have only one neighbor
360  auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
361  auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
362  auto hasFrom = rewriter.create<arith::CmpIOp>(
363  loc, arith::CmpIPredicate::sge, from, zero);
364  auto hasTo = rewriter.create<arith::CmpIOp>(
365  loc, arith::CmpIPredicate::sge, to, zero);
366  auto buffer = rewriter.create<memref::AllocOp>(
367  loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
368  // if has neighbor: copy halo data from array to buffer and send
369  rewriter.create<scf::IfOp>(
370  loc, hasTo, [&](OpBuilder &builder, Location loc) {
371  offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
372  : OpFoldResult(upperSendOffset);
373  auto subview = builder.create<memref::SubViewOp>(
374  loc, array, offsets, dimSizes, strides);
375  builder.create<memref::CopyOp>(loc, subview, buffer);
376  builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
377  builder.create<scf::YieldOp>(loc);
378  });
379  // if has neighbor: receive halo data into buffer and copy to array
380  rewriter.create<scf::IfOp>(
381  loc, hasFrom, [&](OpBuilder &builder, Location loc) {
382  offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
383  : OpFoldResult(lowerRecvOffset);
384  builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
385  auto subview = builder.create<memref::SubViewOp>(
386  loc, array, offsets, dimSizes, strides);
387  builder.create<memref::CopyOp>(loc, buffer, subview);
388  builder.create<scf::YieldOp>(loc);
389  });
390  rewriter.create<memref::DeallocOp>(loc, buffer);
391  offsets[dim] = orgOffset;
392  };
393 
394  genSendRecv(false);
395  genSendRecv(true);
396 
397  // the shape for lower dims include higher dims' halos
398  dimSizes[dim] = shape[dim];
399  // -> the offset for higher dims is always 0
400  offsets[dim] = rewriter.getIndexAttr(0);
401  // on to next halo
402  --currHaloDim;
403  }
404 
405  if (isa<MemRefType>(op.getResult().getType())) {
406  rewriter.replaceOp(op, array);
407  } else {
408  assert(isa<RankedTensorType>(op.getResult().getType()));
409  rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
410  loc, op.getResult().getType(), array,
411  /*restrict=*/true, /*writable=*/true));
412  }
413  return mlir::success();
414  }
415 };
416 
417 struct ConvertMeshToMPIPass
418  : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
419  using Base::Base;
420 
421  /// Run the dialect converter on the module.
422  void runOnOperation() override {
423  auto *ctx = &getContext();
425 
426  patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
427  ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
428  ctx);
429 
430  (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
431  }
432 };
433 
434 } // namespace
435 
436 // Create a pass that convert Mesh to MPI
437 std::unique_ptr<::mlir::Pass> mlir::createConvertMeshToMPIPass() {
438  return std::make_unique<ConvertMeshToMPIPass>();
439 }
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
IntegerType getI32Type()
Definition: Builders.cpp:107
IndexType getIndexType()
Definition: Builders.cpp:95
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:126
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
std::unique_ptr<::mlir::Pass > createConvertMeshToMPIPass()
Lowers Mesh communication operations (updateHalo, AllGather, ...) to MPI primitives.
Definition: MeshToMPI.cpp:437
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362