//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation of Mesh communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"

#define DEBUG_TYPE "mesh-to-mpi"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mesh;

namespace {
/// Converts a vector of OpFoldResults (ints) into a vector of Values of the
/// provided type.
static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                           ArrayRef<int64_t> statics,
                                           ValueRange dynamics,
                                           Type type = Type()) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  if (!type)
    type = i64;
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
    }
  }
  return values;
}
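
// Illustrative note (not from the original source): for statics
// {4, ShapedType::kDynamic} and one dynamic Value %d, the helper above
// returns {arith.constant 4, %d}, i.e. one Value per static/dynamic entry.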

/// Create operations converting a linear index to a multi-dimensional index.
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                             Value linearIndex,
                                             ValueRange dimensions) {
  int n = dimensions.size();
  SmallVector<Value> multiIndex(n);

  for (int i = n - 1; i >= 0; --i) {
    multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
    if (i > 0)
      linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
  }

  return multiIndex;
}
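
// Worked example (illustrative): for dimensions [2, 3] (row-major) and
// linearIndex 5, the loop computes multiIndex [1, 2]:
//   i = 1: 5 % 3 = 2, linearIndex = 5 / 3 = 1;  i = 0: 1 % 2 = 1.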

/// Create operations converting a multi-dimensional index to a linear index.
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                         ValueRange dimensions) {

  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);

  for (int i = multiIndex.size() - 1; i >= 0; --i) {
    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
    linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
    stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
  }

  return linearIndex;
}
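
// The inverse of the example above: multiIndex [1, 2] with dimensions
// [2, 3] accumulates 2 * 1 + 1 * 3 = 5, recovering the linear index.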

/// Replace GetShardingOp with related/dependent ShardingOp.
struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
    if (!shardOp)
      return failure();
    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
    if (!shardingOp)
      return failure();

    rewriter.replaceOp(op, shardingOp.getResult());
    return success();
  }
};

/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by the type converter.
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // To hold the split axes, create an empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
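    // Illustrative example: split axes [[0], [1, 2]] produce the 2x2 i16
    // tensor
    //   [[0, -1],
    //    [1,  2]]
    // where -1 pads the shorter split group.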
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                       .getResult(0);

    // Explicitly write values into the tensor row by row.
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
    }

    // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}
    // and store the halo sizes in it.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    Value resHaloSizes =
        haloSizes.empty()
            ? rewriter
                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
                                           i64)
                  .getResult()
            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
                  .getResult();

    // To hold sharded dims offsets, create a tensor with shape {nSplits,
    // maxSplitSize + 1}. Store the offsets in the tensor but set trailing
    // elements for smaller split-groups to -1. Computing the max size of the
    // split groups requires collectiveProcessGroupSize (which needs the
    // MeshOp).
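    // Illustrative example: a dimension sharded into 4 parts of sizes
    // [2, 2, 3, 3] is described by offsets [0, 2, 4, 7, 10], the trailing
    // entry being the total size.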
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto meshOp = getMesh(op, symbolTableCollection);
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      Value fillVal = rewriter.create<arith::ConstantOp>(
          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          rewriter.create<linalg::FillOp>(loc, fillVal, resOffsets).getResult(0);
      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = rewriter.create<tensor::InsertSliceOp>(
            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // Return a tuple of tensors as defined by the type converter.
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    resSplitAxes =
        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};

struct ConvertProcessMultiIndexOp
    : public OpConversionPattern<ProcessMultiIndexOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Currently converts its linear index to a multi-dimensional index.

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();
    auto meshOp = getMesh(op, symbolTableCollection);
    // For now we only support static mesh shapes.
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      return failure();

    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);

    // Optionally extract a subset of mesh axes.
    auto axes = adaptor.getAxes();
    if (!axes.empty()) {
      SmallVector<Value> subIndex;
      for (auto axis : axes) {
        subIndex.emplace_back(mIdx[axis]);
      }
      mIdx = std::move(subIndex);
    }

    rewriter.replaceOp(op, mIdx);
    return success();
  }
};

class ConvertProcessLinearIndexOp
    : public OpConversionPattern<ProcessLinearIndexOp> {
  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0

public:
  using OpConversionPattern::OpConversionPattern;

  // Constructor accepting worldRank.
  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
                              MLIRContext *context, int64_t worldRank = -1)
      : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}

  LogicalResult
  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    if (worldRank >= 0) { // if the rank in MPI_COMM_WORLD is known, use it
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
      return success();
    }

    // Otherwise create and use mpi::CommRankOp.
    auto ctx = op.getContext();
    Value commWorld =
        rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
    auto rank =
        rewriter
            .create<mpi::CommRankOp>(
                loc,
                TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
                commWorld)
            .getRank();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                    rank);
    return success();
  }
};

struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Computes the neighbors' indices along a split axis by simply
    // adding/subtracting 1 to/from the current index in that dimension,
    // and assigns -1 if the neighbor is out of bounds.
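    // Illustrative example: on a 2x4 mesh, device [1, 2] (linear index 6)
    // has, along mesh axis 1, the down neighbor [1, 1] (linear index 5) and
    // the up neighbor [1, 3] (linear index 7).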

    auto axes = adaptor.getSplitAxes();
    // For now only single axis sharding is supported.
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(op, symbolTableCollection);
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
    Value atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, orgIdx,
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    auto down = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
                  .getResult();
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, orgIdx,
        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
    auto up = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};

struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      return op->emitError()
             << "Expected ShardingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
    // computed by dividing the dimension size by the number of shards in that
    // dimension (which is given by the size of the mesh axes provided in
    // split-axes). Odd elements get distributed to trailing shards. If a
    // shardedDimsOffsets is provided, the shard shape is computed by
    // subtracting the offset of the current shard from the offset of the next
    // shard.
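    // Illustrative example: a dimension of size 10 split across 4 shards
    // yields shard sizes [2, 2, 3, 3]: 10 / 4 = 2 for every shard, and the
    // 10 % 4 = 2 leftover elements go to the two trailing shards.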

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValueRange>. For dims and
    // device we have a 1:1 conversion.
    // For simpler access, fill a vector with the dynamic dims.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      // Type conversion should be 1:1 for ints.
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    // Same for device.
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    // To keep the code simple, convert dims/device to values when they are
    // attributes. Count on canonicalization to fold static values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    // Get the MeshOp; the mesh shape is needed to compute the sharded shape.
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(sharding, symbolTableCollection);
    // For now we only support static mesh shapes.
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();
    // shardedDimsOffsets are optional and might be Values (not attributes).
    // Also, the shardId might be dynamic, which means the position in the
    // shardedDimsOffsets is not statically known. Create a tensor of the
    // shardedDimsOffsets and later extract the offsets for computing the
    // local shard-size.
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
    }

    // With a static mesh shape the sizes of the split axes are known.
    // Hence the start position for each split axis in shardedDimsOffsets can
    // be computed statically.
    int64_t pos = 0;
    SmallVector<Value> shardShape;
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
    Value one =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape, get their split axes,
    // and compute the sharded shape.
    for (auto [i, dim] : llvm::enumerate(shape)) {
      // Trailing dimensions might not be annotated.
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        // The current dimension might not be sharded.
        // Create a value from the static position in shardedDimsOffsets.
        Value posVal =
            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
        // Get the index of the local shard in the mesh axis.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        if (shardedDimsOffs) {
          // If sharded dims offsets are provided, use them to compute the
          // sharded shape.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
          // Compute size = shardedDimsOffs[idx + 1] - shardedDimsOffs[idx].
          Value off =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          idx = rewriter.create<arith::AddIOp>(loc, idx, one);
          Value nextOff =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getIndexAttr(numShards));
          // Compute the shard dim size by distributing odd elements to
          // trailing shards:
          //   sz = dim / numShards
          //        + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
          sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
          auto cond = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
          sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        pos += numShards + 1; // add one for the total size
      } // else: no sharding if the split axis is empty or there is none
      // If no size was added, there is no sharding in this dimension.
      if (shardShape.size() <= i)
        shardShape.emplace_back(dim);
    }
    assert(shardShape.size() == shape.size());
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};

struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // The input/output memref is assumed to be in C memory order.
    // Halos are exchanged as 2 blocks per dimension (one for each side: down
    // and up). For each haloed dimension `d`, the exchanged blocks are
    // expressed as multi-dimensional subviews. The subviews include potential
    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
    // `dl < d`, and for dimension `d` the currently exchanged halo only.
    // By iterating from higher to lower dimensions this also updates the
    // halos in the 'corners'.
    // memref.subview is used to read and write the halo data from and to the
    // local data. Because subviews and halos can have mixed dynamic and
    // static shapes, OpFoldResults are used whenever possible.
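    // Illustration for a single split dimension with halo sizes [down, up]:
    //
    //   local buffer: | down halo | owned data | up halo |
    //
    // The owned block adjacent to each border is sent to the respective
    // neighbor, while the neighbor's owned data is received into the halo.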

    auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
                                    adaptor.getHaloSizes(), rewriter);
    if (haloSizes.empty()) {
      // No halos -> nothing to do.
      rewriter.replaceOp(op, adaptor.getDestination());
      return success();
    }

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();

    // Convert an OpFoldResult into a Value.
    auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
      if (auto value = dyn_cast<Value>(v))
        return value;
      return rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(
                   cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
    };

    auto dest = adaptor.getDestination();
    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
    Value array = dest;
    if (isa<RankedTensorType>(array.getType())) {
      // If the destination is a tensor, cast it to a memref.
      auto memrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, array);
    }
    auto rank = cast<ShapedType>(array.getType()).getRank();
    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
    auto mesh = adaptor.getMesh();
    auto meshOp = getMesh(op, symbolTableCollection);
    // Subviews need Index values.
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz))
        sz =
            rewriter
                .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
                .getResult();
    }

    // Most of the offset/size/stride data is the same for all dims.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
    // We need the actual shape to compute offsets and sizes.
    for (auto i = 0; i < rank; ++i) {
      auto s = dstShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
      else
        shape[i] = rewriter.getIndexAttr(s);

      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
        ++currHaloDim;
        // The offsets for the lower dims start after their down halo.
        offsets[i] = haloSizes[currHaloDim * 2];

        // Prepare shape and offsets for the highest dim's halo exchange.
        Value _haloSz = rewriter.create<arith::AddIOp>(
            loc, toValue(haloSizes[currHaloDim * 2]),
            toValue(haloSizes[currHaloDim * 2 + 1]));
        // The halo shape of lower dims excludes the halos.
        dimSizes[i] =
            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
                .getResult();
      } else {
        dimSizes[i] = shape[i];
      }
    }

    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
    auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v < 0
    auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);

    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
                                       rewriter.getIndexType());
    auto myMultiIndex =
        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
            .getResult();
    // Traverse all split axes from high to low dim.
    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
      auto splitAxes = opSplitAxes[dim];
      if (splitAxes.empty())
        continue;
      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
      // Get the linearized ids of the neighbors (down and up) for the
      // given split.
      auto tmp = rewriter
                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
                                                       splitAxes)
                     .getResults();
      // MPI operates on i32...
      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
                                   loc, rewriter.getI32Type(), tmp[0]),
                               rewriter.create<arith::IndexCastOp>(
                                   loc, rewriter.getI32Type(), tmp[1])};

      auto lowerRecvOffset = rewriter.getIndexAttr(0);
      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
      auto upperRecvOffset = rewriter.create<arith::SubIOp>(
          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
      auto upperSendOffset = rewriter.create<arith::SubIOp>(
          loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));

      Value commWorld = rewriter.create<mpi::CommWorldOp>(
          loc, mpi::CommType::get(op->getContext()));

      // Make sure we send/recv in a way that does not lead to a dead-lock.
      // The current approach is by far not optimal; it should at least use
      // a red-black pattern or MPI_sendrecv.
      // Also, buffers should be re-used.
      // Still using temporary contiguous buffers for MPI communication...
      // Still yielding a "serialized" communication pattern...
      auto genSendRecv = [&](bool upperHalo) {
        auto orgOffset = offsets[dim];
        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                  : haloSizes[currHaloDim * 2];
        // Check if we need to send and/or receive.
        // Processes on the mesh borders have only one neighbor.
        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
        auto hasFrom = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, from, zero);
        auto hasTo = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, to, zero);
        auto buffer = rewriter.create<memref::AllocOp>(
            loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
        // If there is a neighbor: copy halo data from array to buffer and
        // send.
        rewriter.create<scf::IfOp>(
            loc, hasTo, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
                                       : OpFoldResult(upperSendOffset);
              auto subview = builder.create<memref::SubViewOp>(
                  loc, array, offsets, dimSizes, strides);
              builder.create<memref::CopyOp>(loc, subview, buffer);
              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to,
                                          commWorld);
              builder.create<scf::YieldOp>(loc);
            });
        // If there is a neighbor: receive halo data into buffer and copy it
        // to array.
        rewriter.create<scf::IfOp>(
            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                       : OpFoldResult(lowerRecvOffset);
              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from,
                                          commWorld);
              auto subview = builder.create<memref::SubViewOp>(
                  loc, array, offsets, dimSizes, strides);
              builder.create<memref::CopyOp>(loc, buffer, subview);
              builder.create<scf::YieldOp>(loc);
            });
        rewriter.create<memref::DeallocOp>(loc, buffer);
        offsets[dim] = orgOffset;
      };

      auto doSendRecv = [&](int upOrDown) {
        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
        Value haloSz = dyn_cast<Value>(v);
        if (!haloSz)
          haloSz = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getI32IntegerAttr(
                       cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
        auto hasSize = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sgt, haloSz, zero);
        rewriter.create<scf::IfOp>(loc, hasSize,
                                   [&](OpBuilder &builder, Location loc) {
                                     genSendRecv(upOrDown > 0);
                                     builder.create<scf::YieldOp>(loc);
                                   });
      };

      doSendRecv(0);
      doSendRecv(1);

      // The shape for lower dims includes higher dims' halos.
      dimSizes[dim] = shape[dim];
      // -> the offset for higher dims is always 0
      offsets[dim] = rewriter.getIndexAttr(0);
      // On to the next halo.
      --currHaloDim;
    }

    if (isa<MemRefType>(op.getResult().getType())) {
      rewriter.replaceOp(op, array);
    } else {
      assert(isa<RankedTensorType>(op.getResult().getType()));
      rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
                                 loc, op.getResult().getType(), array,
                                 /*restrict=*/true, /*writable=*/true));
    }
    return success();
  }
};
749 
750 struct ConvertMeshToMPIPass
751  : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
752  using Base::Base;
753 
754  /// Run the dialect converter on the module.
755  void runOnOperation() override {
756  uint64_t worldRank = -1;
757  // Try to get DLTI attribute for MPI:comm_world_rank
758  // If found, set worldRank to the value of the attribute.
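    // For illustration only (the exact DLTI syntax may differ between MLIR
    // versions), such an attribute could be attached as:
    //   module attributes {dlti.map = #dlti.map<"MPI:comm_world_rank" = 1>} {
    //     ...
    //   }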
    {
      auto dltiAttr =
          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
      if (succeeded(dltiAttr)) {
        if (!isa<IntegerAttr>(dltiAttr.value())) {
          getOperation()->emitError()
              << "Expected an integer attribute for MPI:comm_world_rank";
          return signalPassFailure();
        }
        worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
      }
    }

    auto *ctxt = &getContext();
    RewritePatternSet patterns(ctxt);
    ConversionTarget target(getContext());

    // Define a type converter to convert mesh::ShardingType,
    // mostly for use in return operations.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });

    // Convert mesh::ShardingType to a tuple of RankedTensorTypes.
    typeConverter.addConversion(
        [](ShardingType type,
           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          auto i16 = IntegerType::get(type.getContext(), 16);
          auto i64 = IntegerType::get(type.getContext(), 64);
          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
                                        ShapedType::kDynamic};
          results.emplace_back(RankedTensorType::get(shp, i16));
          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
          results.emplace_back(RankedTensorType::get(shp, i64));
          return success();
        });

    // To 'extract' components, an UnrealizedConversionCastOp is expected
    // to define the input.
    typeConverter.addTargetMaterialization(
        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
            Location loc) {
          // Expecting a single input.
          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
            return SmallVector<Value>();
          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
          // Expecting an UnrealizedConversionCastOp.
          if (!castOp)
            return SmallVector<Value>();
          // Fill a vector with the elements of the tuple/castOp.
          SmallVector<Value> results;
          for (auto oprnd : castOp.getInputs()) {
            if (!isa<RankedTensorType>(oprnd.getType()))
              return SmallVector<Value>();
            results.emplace_back(oprnd);
          }
          return results;
        });

    // No mesh dialect ops should be left after conversion...
    target.addIllegalDialect<mesh::MeshDialect>();
    // ...except the global MeshOp.
    target.addLegalOp<mesh::MeshOp>();
    // Allow all the stuff that our patterns will convert to.
    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
                           arith::ArithDialect, tensor::TensorDialect,
                           bufferization::BufferizationDialect,
                           linalg::LinalgDialect, memref::MemRefDialect>();
    // Make sure the function signature, calls, etc. are legal.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
                 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
    // ConvertProcessLinearIndexOp accepts an optional worldRank.
    patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    populateCallOpTypeConversionPattern(patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);

    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
  }
};

} // namespace