MeshToMPI.cpp
//===- MeshToMPI.cpp - Mesh to MPI dialect conversion ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation of Mesh communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"

#define DEBUG_TYPE "mesh-to-mpi"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mesh;

namespace {
/// Converts a list of mixed static/dynamic sizes (static ints plus dynamic
/// Values) into a vector of Values of the provided type.
static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                           ArrayRef<int64_t> statics,
                                           ValueRange dynamics,
                                           Type type = Type()) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  if (!type)
    type = i64;
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
    }
  }
  return values;
}

/// Create operations converting a linear index to a multi-dimensional index.
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                             Value linearIndex,
                                             ValueRange dimensions) {
  int n = dimensions.size();
  SmallVector<Value> multiIndex(n);

  for (int i = n - 1; i >= 0; --i) {
    multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
    if (i > 0)
      linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
  }

  return multiIndex;
}
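// A brief worked example (illustrative only, not from the original source):
// for dimensions (2, 3) and linearIndex 5 the loop yields
// multiIndex = (1, 2), i.e. a row-major (C-order) delinearization, since
// 5 = 1 * 3 + 2.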

/// Create operations converting a multi-dimensional index to a linear index.
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                         ValueRange dimensions) {

  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);

  for (int i = multiIndex.size() - 1; i >= 0; --i) {
    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
    linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
    stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
  }

  return linearIndex;
}
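// The inverse of linearToMultiIndex (again illustrative only): for dimensions
// (2, 3) and multiIndex (1, 2) the running stride takes the values 1 and 3,
// giving linearIndex = 2 * 1 + 1 * 3 = 5.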

/// Replace GetShardingOp with related/dependent ShardingOp.
struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
    if (!shardOp)
      return failure();
    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
    if (!shardingOp)
      return failure();

    rewriter.replaceOp(op, shardingOp.getResult());
    return success();
  }
};

/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by the type converter.
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // To hold the split axes, create an empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
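    // For example (illustrative only): splitAxes = [[0], [1, 2]] becomes the
    // tensor<2x2xi16> [[0, -1], [1, 2]], with -1 padding the shorter group.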
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                       .getResult(0);

    // Explicitly write values into the tensor row by row.
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
    }

    // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    Value resHaloSizes =
        haloSizes.empty()
            ? rewriter
                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
                                           i64)
                  .getResult()
            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
                  .getResult();

    // To hold sharded dims offsets, create a tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
    // elements for smaller split-groups to -1. Computing the max size of the
    // split groups needs collectiveProcessGroupSize (which needs the
    // MeshOp).
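    // For example (illustrative only): a 12-element dimension split across 3
    // shards could occupy one row as [0, 4, 8, 12]; rows of smaller split
    // groups keep their trailing fill elements.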
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto meshOp = getMesh(op, symbolTableCollection);
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = rewriter.create<tensor::InsertSliceOp>(
            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // Return a tuple of tensors as defined by the type converter.
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    resSplitAxes =
        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};

struct ConvertProcessMultiIndexOp
    : public OpConversionPattern<ProcessMultiIndexOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Currently converts its linear index to a multi-dimensional index.

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();
    auto meshOp = getMesh(op, symbolTableCollection);
    // For now we only support static mesh shapes.
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      return failure();

    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);

    // Optionally extract a subset of mesh axes.
    auto axes = adaptor.getAxes();
    if (!axes.empty()) {
      SmallVector<Value> subIndex;
      for (auto axis : axes) {
        subIndex.emplace_back(mIdx[axis]);
      }
      mIdx = std::move(subIndex);
    }

    rewriter.replaceOp(op, mIdx);
    return success();
  }
};

class ConvertProcessLinearIndexOp
    : public OpConversionPattern<ProcessLinearIndexOp> {
  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0

public:
  using OpConversionPattern::OpConversionPattern;

  // Constructor accepting worldRank
  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
                              MLIRContext *context, int64_t worldRank = -1)
      : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}

  LogicalResult
  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    if (worldRank >= 0) { // if the rank in MPI_COMM_WORLD is known -> use it
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
      return success();
    }

    // Otherwise create an mpi::CommRankOp.
    auto rank = rewriter
                    .create<mpi::CommRankOp>(
                        loc, TypeRange{mpi::RetvalType::get(op->getContext()),
                                       rewriter.getI32Type()})
                    .getRank();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                    rank);
    return success();
  }
};

struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Computes the neighbor indices along a split axis by simply
    // adding/subtracting 1 to/from the current index in that dimension.
    // Assigns -1 if the neighbor is out of bounds.
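    // For example (illustrative only): on a 2x4 mesh, device (1, 2) split
    // along axis 1 has the down-neighbor (1, 1) -> linear index 5 and the
    // up-neighbor (1, 3) -> linear index 7; at a mesh border the respective
    // result is -1.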

    auto axes = adaptor.getSplitAxes();
    // For now only single axis sharding is supported.
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(op, symbolTableCollection);
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
    Value atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, orgIdx,
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    auto down = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
                  .getResult();
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, orgIdx,
        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
    auto up = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};

struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      return op->emitError()
             << "Expected ShardingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
    // computed by dividing the dimension size by the number of shards in that
    // dimension (which is given by the size of the mesh axes provided in
    // split-axes). Odd elements get distributed to trailing shards. If a
    // shardedDimsOffsets is provided, the shard shape is computed by
    // subtracting the offset of the current shard from the offset of the next
    // shard.
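    // For example (illustrative only): dim = 10 and numShards = 4 gives
    // 10 / 4 = 2 with remainder 2, so the trailing dim % numShards shards get
    // one extra element and the shard sizes are (2, 2, 3, 3).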

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValueRange>. For dims and
    // device we have a 1:1 conversion.
    // For simpler access fill a vector with the dynamic dims.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      // type conversion should be 1:1 for ints
      assert(dim.size() == 1);
      dynDims.emplace_back(dim[0]);
    }
    // same for device
    for (auto device : adaptor.getDeviceDynamic()) {
      assert(device.size() == 1);
      dynDevice.emplace_back(device[0]);
    }

    // To keep the code simple, convert dims/device to values when they are
    // attributes. Count on canonicalization to fold static values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    // Get the MeshOp; the mesh shape is needed to compute the sharded shape.
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(sharding, symbolTableCollection);
    // For now we only support static mesh shapes.
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();
    // shardedDimsOffsets are optional and might be Values (not attributes).
    // Also, the shardId might be dynamic, which means the position in the
    // shardedDimsOffsets is not statically known. Create a tensor of the
    // shardedDimsOffsets and later extract the offsets for computing the
    // local shard-size.
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
    }

    // With a static mesh shape the sizes of the split axes are known.
    // Hence the start/pos for each split axis in shardedDimsOffsets can be
    // computed statically.
    int64_t pos = 0;
    SmallVector<Value> shardShape;
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
    Value one =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape, get their split axes,
    // and compute the sharded shape.
    for (auto [i, dim] : llvm::enumerate(shape)) {
      // Trailing dimensions might not be annotated.
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        // The current dimension might not be sharded.
        // Create a value from the static position in shardedDimsOffsets.
        Value posVal =
            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
        // Get the index of the local shard in the mesh axis.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        if (shardedDimsOffs) {
          // If sharded dims offsets are provided, use them to compute the
          // sharded shape.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
          // Compute size = shardedDimsOffs[idx + 1] - shardedDimsOffs[idx].
          Value off =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          idx = rewriter.create<arith::AddIOp>(loc, idx, one);
          Value nextOff =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getIndexAttr(numShards));
          // Compute the shard dim size by distributing odd elements to
          // trailing shards:
          //   sz = dim / numShards
          //        + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
          sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
          auto cond = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
          sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        pos += numShards + 1; // add one for the total size
      } // else: no sharding if the split axes are empty or not given
      // If no size was added -> no sharding in this dimension.
      if (shardShape.size() <= i)
        shardShape.emplace_back(dim);
    }
    assert(shardShape.size() == shape.size());
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};

struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // The input/output memref is assumed to be in C memory order.
    // Halos are exchanged as 2 blocks per dimension (one for each side: down
    // and up). For each haloed dimension `d`, the exchanged blocks are
    // expressed as multi-dimensional subviews. The subviews include potential
    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
    // `dl < d` and for dimension `d` the currently exchanged halo only.
    // By iterating from higher to lower dimensions this also updates the
    // halos in the 'corners'.
    // memref.subview is used to read and write the halo data from and to the
    // local data. Because subviews and halos can have mixed dynamic and static
    // shapes, OpFoldResults are used whenever possible.
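    // For example (illustrative only): for a rank-2 array haloed in both
    // dimensions, dim 1 is exchanged first using only the non-halo rows of
    // dim 0; the subsequent dim-0 exchange then transfers full rows including
    // the dim-1 halos, which is what fills the corners.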

    auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
                                    adaptor.getHaloSizes(), rewriter);
    if (haloSizes.empty()) {
      // no halos -> nothing to do
      rewriter.replaceOp(op, adaptor.getDestination());
      return success();
    }

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();

    // Convert an OpFoldResult into a Value.
    auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
      if (auto value = dyn_cast<Value>(v))
        return value;
      return rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(
                   cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
    };

    auto dest = adaptor.getDestination();
    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
    Value array = dest;
    if (isa<RankedTensorType>(array.getType())) {
      // If the destination is a tensor, cast it to a memref.
      auto memrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, array);
    }
    auto rank = cast<ShapedType>(array.getType()).getRank();
    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
    auto mesh = adaptor.getMesh();
    auto meshOp = getMesh(op, symbolTableCollection);
    // Subviews need Index values.
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz))
        sz =
            rewriter
                .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
                .getResult();
    }

    // Most of the offset/size/stride data is the same for all dims.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
    // We need the actual shape to compute offsets and sizes.
    for (auto i = 0; i < rank; ++i) {
      auto s = dstShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
      else
        shape[i] = rewriter.getIndexAttr(s);

      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
        ++currHaloDim;
        // The offsets for lower dims start after their down halo.
        offsets[i] = haloSizes[currHaloDim * 2];

        // Prepare shape and offsets for the highest dim's halo exchange.
        Value _haloSz = rewriter.create<arith::AddIOp>(
            loc, toValue(haloSizes[currHaloDim * 2]),
            toValue(haloSizes[currHaloDim * 2 + 1]));
        // The halo shape of lower dims excludes the halos.
        dimSizes[i] =
            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
                .getResult();
      } else {
        dimSizes[i] = shape[i];
      }
    }

    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
    auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v < 0
    auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);

    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
                                       rewriter.getIndexType());
    auto myMultiIndex =
        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
            .getResult();
    // Traverse all split axes from high to low dim.
    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
      auto splitAxes = opSplitAxes[dim];
      if (splitAxes.empty())
        continue;
      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
      // Get the linearized ids of the neighbors (down and up) for the
      // given split.
      auto tmp = rewriter
                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
                                                       splitAxes)
                     .getResults();
      // MPI operates on i32...
      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
                                   loc, rewriter.getI32Type(), tmp[0]),
                               rewriter.create<arith::IndexCastOp>(
                                   loc, rewriter.getI32Type(), tmp[1])};

      auto lowerRecvOffset = rewriter.getIndexAttr(0);
      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
      auto upperRecvOffset = rewriter.create<arith::SubIOp>(
          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
      auto upperSendOffset = rewriter.create<arith::SubIOp>(
          loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));

      // Make sure we send/recv in a way that does not lead to a dead-lock.
      // The current approach is by far not optimal; it should at least use a
      // red-black pattern or MPI_sendrecv.
      // Also, buffers should be re-used.
      // Still using temporary contiguous buffers for MPI communication...
      // Still yielding a "serialized" communication pattern...
      auto genSendRecv = [&](bool upperHalo) {
        auto orgOffset = offsets[dim];
        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                  : haloSizes[currHaloDim * 2];
        // Check if we need to send and/or receive.
        // Processes on the mesh borders have only one neighbor.
        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
        auto hasFrom = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, from, zero);
        auto hasTo = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, to, zero);
        auto buffer = rewriter.create<memref::AllocOp>(
            loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
        // If we have a neighbor: copy halo data from array to buffer and send.
        rewriter.create<scf::IfOp>(
            loc, hasTo, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
                                       : OpFoldResult(upperSendOffset);
              auto subview = builder.create<memref::SubViewOp>(
                  loc, array, offsets, dimSizes, strides);
              builder.create<memref::CopyOp>(loc, subview, buffer);
              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
              builder.create<scf::YieldOp>(loc);
            });
        // If we have a neighbor: receive halo data into buffer and copy to
        // array.
        rewriter.create<scf::IfOp>(
            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                       : OpFoldResult(lowerRecvOffset);
              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
              auto subview = builder.create<memref::SubViewOp>(
                  loc, array, offsets, dimSizes, strides);
              builder.create<memref::CopyOp>(loc, buffer, subview);
              builder.create<scf::YieldOp>(loc);
            });
        rewriter.create<memref::DeallocOp>(loc, buffer);
        offsets[dim] = orgOffset;
      };

      auto doSendRecv = [&](int upOrDown) {
        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
        Value haloSz = dyn_cast<Value>(v);
        if (!haloSz)
          haloSz = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getI32IntegerAttr(
                       cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
        auto hasSize = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sgt, haloSz, zero);
        rewriter.create<scf::IfOp>(loc, hasSize,
                                   [&](OpBuilder &builder, Location loc) {
                                     genSendRecv(upOrDown > 0);
                                     builder.create<scf::YieldOp>(loc);
                                   });
      };

      doSendRecv(0);
      doSendRecv(1);

      // The shape for lower dims includes higher dims' halos...
      dimSizes[dim] = shape[dim];
      // ...so the offset for higher dims is always 0.
      offsets[dim] = rewriter.getIndexAttr(0);
      // On to the next halo.
      --currHaloDim;
    }

    if (isa<MemRefType>(op.getResult().getType())) {
      rewriter.replaceOp(op, array);
    } else {
      assert(isa<RankedTensorType>(op.getResult().getType()));
      rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
                                 loc, op.getResult().getType(), array,
                                 /*restrict=*/true, /*writable=*/true));
    }
    return success();
  }
};

struct ConvertMeshToMPIPass
    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
  using Base::Base;

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    int64_t worldRank = -1;
    // Try to get the DLTI attribute for MPI:comm_world_rank.
    // If found, set worldRank to the value of the attribute.
    {
      auto dltiAttr =
          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
      if (succeeded(dltiAttr)) {
        if (!isa<IntegerAttr>(dltiAttr.value())) {
          getOperation()->emitError()
              << "Expected an integer attribute for MPI:comm_world_rank";
          return signalPassFailure();
        }
        worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
      }
    }

    auto *ctxt = &getContext();
    RewritePatternSet patterns(ctxt);
    ConversionTarget target(getContext());

    // Define a type converter to convert mesh::ShardingType,
    // mostly for use in return operations.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });

    // Convert mesh::ShardingType to a tuple of RankedTensorTypes.
    typeConverter.addConversion(
        [](ShardingType type,
           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          auto i16 = IntegerType::get(type.getContext(), 16);
          auto i64 = IntegerType::get(type.getContext(), 64);
          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
                                        ShapedType::kDynamic};
          results.emplace_back(RankedTensorType::get(shp, i16));
          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
          results.emplace_back(RankedTensorType::get(shp, i64));
          return success();
        });
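    // I.e. a mesh::ShardingType value is materialized as a tuple of three
    // tensors (tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) holding
    // the split axes, the halo sizes (logically ?x2) and the sharded-dims
    // offsets, respectively.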

    // To 'extract' components, an UnrealizedConversionCastOp is expected
    // to define the input.
    typeConverter.addTargetMaterialization(
        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
            Location loc) {
          // Expecting a single input.
          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
            return SmallVector<Value>();
          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
          // Expecting an UnrealizedConversionCastOp.
          if (!castOp)
            return SmallVector<Value>();
          // Fill a vector with the elements of the tuple/castOp.
          SmallVector<Value> results;
          for (auto oprnd : castOp.getInputs()) {
            if (!isa<RankedTensorType>(oprnd.getType()))
              return SmallVector<Value>();
            results.emplace_back(oprnd);
          }
          return results;
        });

    // No mesh dialect ops should be left after the conversion...
    target.addIllegalDialect<mesh::MeshDialect>();
    // ...except the global MeshOp.
    target.addLegalOp<mesh::MeshOp>();
    // Allow all the dialects that our patterns convert to.
    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
                           arith::ArithDialect, tensor::TensorDialect,
                           bufferization::BufferizationDialect,
                           linalg::LinalgDialect, memref::MemRefDialect>();
    // Make sure the function signature, calls etc. are legal.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
                 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
    // ConvertProcessLinearIndexOp accepts an optional worldRank.
    patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    populateCallOpTypeConversionPattern(patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);

    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
  }
};

} // namespace