//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation of Shard communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"

// The original include block was elided in this listing; the headers below
// are a best-effort reconstruction based on what the code uses.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

#define DEBUG_TYPE "shard-to-mpi"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace shard;

namespace {
/// Converts a vector of OpFoldResults (ints) into vector of Values of the
/// provided type.
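/// E.g. statics = {2, ShapedType::kDynamic, 4} with one dynamic Value %d
/// yields {constant 2, %d, constant 4}, the constants having the requested
/// type (i64 by default).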
static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                           llvm::ArrayRef<int64_t> statics,
                                           ValueRange dynamics,
                                           Type type = Type()) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  if (!type)
    type = i64;
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
    }
  }
  return values;
}

/// Create operations converting a linear index to a multi-dimensional index.
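/// Assumes row-major (C) order, i.e. the last dimension varies fastest.
/// E.g. with dimensions = [2, 3, 4], linearIndex = 17 yields [1, 1, 1]
/// (17 % 4 = 1, 17 / 4 = 4; 4 % 3 = 1, 4 / 3 = 1; 1 % 2 = 1).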
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                             Value linearIndex,
                                             ValueRange dimensions) {
  int n = dimensions.size();
  SmallVector<Value> multiIndex(n);

  for (int i = n - 1; i >= 0; --i) {
    multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
    if (i > 0)
      linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
  }

  return multiIndex;
}

/// Create operations converting a multi-dimensional index to a linear index.
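/// The inverse of linearToMultiIndex for row-major order:
/// linear = sum_d multiIndex[d] * prod(dimensions[d+1:]).
/// E.g. with dimensions = [2, 3, 4], multiIndex = [1, 1, 1] yields
/// 1*12 + 1*4 + 1*1 = 17.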
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                         ValueRange dimensions) {

  Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0);
  Value stride = arith::ConstantIndexOp::create(b, loc, 1);

  for (int i = multiIndex.size() - 1; i >= 0; --i) {
    Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
    linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
    stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
  }

  return linearIndex;
}

/// Replace GetShardingOp with related/dependent ShardingOp.
struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
    if (!shardOp)
      return failure();
    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
    if (!shardingOp)
      return failure();

    rewriter.replaceOp(op, shardingOp.getResult());
    return success();
  }
};

/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by the type converter.
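/// E.g. split axes [[0], [], [1, 2]] become a 3x2 i16 tensor
/// [[0, -1], [-1, -1], [1, 2]]; halo sizes and sharded-dims offsets are
/// emitted as i64 tensors with one row per non-empty split group.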
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // To hold the split axes, create empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
    resSplitAxes =
        linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
            .getResult(0);

    // explicitly write values into tensor row by row
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
      resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resSplitAxes, empty, empty,
                                                   empty, offs, sizes, strides);
    }

    // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    Value resHaloSizes =
        haloSizes.empty()
            ? tensor::EmptyOp::create(rewriter, loc,
                                      std::array<int64_t, 2>{0, 0}, i64)
                  .getResult()
            : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
                  .getResult();

    // To hold sharded dims offsets, create a tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor and set trailing
    // elements for smaller split-groups to a sentinel (ShapedType::kDynamic).
    // Computing the max size of the split groups requires
    // collectiveProcessGroupSize (which needs the GridOp).
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      resOffsets = tensor::EmptyOp::create(rewriter, loc,
                                           std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto gridOp = getGrid(op, symbolTableCollection);
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = tensor::EmptyOp::create(
          rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      Value sentinel = arith::ConstantOp::create(
          rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          linalg::FillOp::create(rewriter, loc, sentinel, resOffsets)
              .getResult(0);
      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resOffsets, empty, empty,
                                                   empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // return a tuple of tensors as defined by type converter
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    resSplitAxes =
        tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
    resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};

struct ConvertProcessMultiIndexOp
    : public OpConversionPattern<ProcessMultiIndexOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Converts the process's linear index into a multi-dimensional index.

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();
    auto gridOp = getGrid(op, symbolTableCollection);
    // For now we only support static grid shapes
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return failure();

    SmallVector<Value> dims;
    llvm::transform(
        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
        });
    Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);

    // optionally extract subset of grid axes
    auto axes = adaptor.getAxes();
    if (!axes.empty()) {
      SmallVector<Value> subIndex;
      for (auto axis : axes) {
        subIndex.emplace_back(mIdx[axis]);
      }
      mIdx = std::move(subIndex);
    }

    rewriter.replaceOp(op, mIdx);
    return success();
  }
};

class ConvertProcessLinearIndexOp
    : public OpConversionPattern<ProcessLinearIndexOp> {

public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Create mpi::CommRankOp
    Location loc = op.getLoc();
    auto *ctx = op.getContext();
    Value commWorld =
        mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
    auto rank = mpi::CommRankOp::create(
                    rewriter, loc,
                    TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
                    commWorld)
                    .getRank();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                    rank);
    return success();
  }
};

struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Computes the neighbor indices along a split axis by adding/subtracting
    // 1 to/from the current index in that dimension.
    // Assigns -1 if the neighbor is out of bounds.
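    // E.g. on a 2x4 grid with split axis 1, device [1, 0] has no lower
    // neighbor (-1 is returned) and upper neighbor [1, 1], i.e. linear
    // index 1*4 + 1 = 5.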

    auto axes = adaptor.getSplitAxes();
    // For now only single axis sharding is supported
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(op, symbolTableCollection);
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    SmallVector<Value> dims;
    llvm::transform(
        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1);
    Value atBorder =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
                              arith::ConstantIndexOp::create(rewriter, loc, 0));
    auto down = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
                  .getResult();
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    atBorder = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
        arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
    auto up = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};

struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      return op->emitError()
             << "Expected ShardingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
    // computed by dividing the dimension size by the number of shards in that
    // dimension (which is given by the size of the grid axes provided in
    // split-axes). Leftover (remainder) elements get distributed to trailing
    // shards. If a shardedDimsOffsets is provided, the shard shape is computed
    // by subtracting the offset of the current shard from the offset of the
    // next shard.
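    // E.g. a dimension of size 10 split over 4 shards yields shard sizes
    // [2, 2, 3, 3]: 10 / 4 = 2 everywhere, and the 10 % 4 = 2 leftover
    // elements go to the trailing shards.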

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValueRange>. For dims and
    // device we have a 1:1 conversion.
    // For simpler access fill a vector with the dynamic dims.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      // type conversion should be 1:1 for ints
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    // same for device
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    // To keep the code simple, convert dims/device to values when they are
    // attributes. Count on canonicalization to fold static values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    // Get the GridOp, the grid shape is needed to compute the sharded shape.
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(sharding, symbolTableCollection);
    // For now we only support static grid shapes
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();
    // shardedDimsOffsets are optional and might be Values (not attributes).
    // Also, the shardId might be dynamic which means the position in the
    // shardedDimsOffsets is not statically known. Create a tensor of the
    // shardedDimsOffsets and later extract the offsets for computing the
    // local shard-size.
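    // E.g. offsets [0, 3, 6, 10] for a dimension split over 3 shards give
    // shard 1 the size offsets[2] - offsets[1] = 3.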
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = tensor::FromElementsOp::create(
            rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
            tmp);
    }

    // With static grid shape the sizes of the split axes are known.
    // Hence the start/pos for each split axis in shardDimsOffsets can be
    // computed statically.
    int64_t pos = 0;
    SmallVector<Value> shardShape;
    Value zero =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
    Value one =
        arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape, get their split Axes,
    // and compute the sharded shape.
    for (auto [i, dim] : llvm::enumerate(shape)) {
      // Trailing dimensions might not be annotated.
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        // The current dimension might not be sharded.
        // Create a value from the static position in shardDimsOffsets.
        Value posVal = arith::ConstantOp::create(rewriter, loc,
                                                 rewriter.getIndexAttr(pos));
        // Get the index of the local shard in the grid axis.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        if (shardedDimsOffs) {
          // If sharded dims offsets are provided, use them to compute the
          // sharded shape.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
          Value off =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          idx = arith::AddIOp::create(rewriter, loc, idx, one);
          Value nextOff =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = arith::ConstantOp::create(
              rewriter, loc, rewriter.getIndexAttr(numShards));
          // Compute shard dim size by distributing odd elements to trailing
          // shards:
          // sz = dim / numShards
          //      + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
          Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
          Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
          sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
          auto cond = arith::CmpIOp::create(
              rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
          sz = arith::AddIOp::create(rewriter, loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        pos += numShards + 1; // add one for the total size.
      } // else: no sharding if the split axes are empty or the dimension is
        // not annotated.
      // If no size was added -> no sharding in this dimension.
      if (shardShape.size() <= i)
        shardShape.emplace_back(dim);
    }
    assert(shardShape.size() == shape.size());
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};

static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
  auto *ctx = kind.getContext();
  auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
    return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
  };

  switch (kind.getValue()) {
  case ReductionKind::Sum:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
  case ReductionKind::Product:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
  case ReductionKind::Min:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
  case ReductionKind::Max:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
  case ReductionKind::BitwiseAnd:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
  case ReductionKind::BitwiseOr:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
  case ReductionKind::BitwiseXor:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
  default:
    llvm_unreachable("Unknown/unsupported reduction kind");
  }
}

struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SymbolTableCollection symbolTableCollection;
    auto grid = adaptor.getGrid();
    mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
    if (!gridOp)
      return op->emitError() << "No grid found for AllReduceOp";
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return op->emitError()
             << "Dynamic grid shape not supported in AllReduceOp";

    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
    Value input = adaptor.getInput();
    auto inputShape = cast<ShapedType>(input.getType()).getShape();

    // If the source is a tensor, convert it to a memref.
    if (isa<RankedTensorType>(input.getType())) {
      auto memrefType = MemRefType::get(
          inputShape, cast<ShapedType>(input.getType()).getElementType());
      input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
    }
    MemRefType inType = cast<MemRefType>(input.getType());

    // Get the actual shape to allocate the buffer.
    SmallVector<OpFoldResult> shape(inType.getRank());
    for (auto i = 0; i < inType.getRank(); ++i) {
      auto s = inputShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = memref::DimOp::create(iBuilder, input, i).getResult();
      else
        shape[i] = iBuilder.getIndexAttr(s);
    }

    // Allocate buffer and copy input to buffer.
    Value buffer = memref::AllocOp::create(
        iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
    linalg::CopyOp::create(iBuilder, input, buffer);

    // Get an MPI_Comm_split for the AllReduce operation.
    // The color is the linear index of the process in the grid along the
    // non-reduced axes. The key is the linear index of the process in the grid
    // along the reduced axes.
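    // E.g. on a 2x4 grid reducing along grid axis 1, all processes in the
    // same row share a color (their index along axis 0) and are ordered
    // within the new communicator by their index along axis 1.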
    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
                                       iBuilder.getIndexType());
    SmallVector<Value> myMultiIndex =
        ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
            .getResult();
    Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
    SmallVector<Value> multiKey(myMultiIndex.size(), zero);

    auto redAxes = adaptor.getGridAxes();
    for (auto axis : redAxes) {
      multiKey[axis] = myMultiIndex[axis];
      myMultiIndex[axis] = zero;
    }

    Value color =
        createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
    color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
    Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
    key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);

    // Finally split the communicator
    auto commType = mpi::CommType::get(op->getContext());
    Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
    auto comm =
        mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
            .getNewcomm();

    Value buffer1d = buffer;
    // Collapse shape to 1d if needed
    if (inType.getRank() > 1) {
      ReassociationIndices reassociation(inType.getRank());
      std::iota(reassociation.begin(), reassociation.end(), 0);
      buffer1d = memref::CollapseShapeOp::create(
          iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
    }

    // Create the MPI AllReduce operation.
    mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
                             getMPIReductionOp(adaptor.getReductionAttr()),
                             comm);

    // If the result is a tensor, convert the buffer back to a tensor.
    if (isa<RankedTensorType>(op.getType()))
      buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
                                                 /*restrict=*/true);

    rewriter.replaceOp(op, buffer);
    return success();
  }
};

struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // The input/output memref is assumed to be in C memory order.
    // Halos are exchanged as 2 blocks per dimension (one for each side: down
    // and up). For each haloed dimension `d`, the exchanged blocks are
    // expressed as multi-dimensional subviews. The subviews include potential
    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
    // `dl < d` and for dimension `d` the currently exchanged halo only.
    // By iterating from higher to lower dimensions this also updates the halos
    // in the 'corners'.
    // memref.subview is used to read and write the halo data from and to the
    // local data. Because subviews and halos can have mixed dynamic and static
    // shapes, OpFoldResults are used whenever possible.
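    // E.g. for a 2d array haloed in both dimensions, dimension 1 is exchanged
    // first with subviews that exclude dimension 0's halos; the subsequent
    // exchange of dimension 0 then spans dimension 1's full extent, including
    // its freshly received halos, which fills the corners.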

    auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
                                    adaptor.getHaloSizes(), rewriter);
    if (haloSizes.empty()) {
      // no halos -> nothing to do
      rewriter.replaceOp(op, adaptor.getDestination());
      return success();
    }

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();

    // convert an OpFoldResult into a Value
    auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
      if (auto value = dyn_cast<Value>(v))
        return value;
      return arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getIndexAttr(
              cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
    };

    auto dest = adaptor.getDestination();
    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
    Value array = dest;
    if (isa<RankedTensorType>(array.getType())) {
      // If the destination is a tensor, cast it to a memref.
      auto memrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          bufferization::ToBufferOp::create(rewriter, loc, memrefType, array);
    }
    auto rank = cast<ShapedType>(array.getType()).getRank();
    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
    auto grid = adaptor.getGrid();
    auto gridOp = getGrid(op, symbolTableCollection);
    // subviews need Index values
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz))
        sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
                                        value)
                 .getResult();
    }

    // most of the offset/size/stride data is the same for all dims
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
    // we need the actual shape to compute offsets and sizes
    for (auto i = 0; i < rank; ++i) {
      auto s = dstShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = memref::DimOp::create(rewriter, loc, array, i).getResult();
      else
        shape[i] = rewriter.getIndexAttr(s);

      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
        ++currHaloDim;
        // the offsets for lower dims start after their down halo
        offsets[i] = haloSizes[currHaloDim * 2];

        // prepare shape and offsets of highest dim's halo exchange
        Value _haloSz = arith::AddIOp::create(
            rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
            toValue(haloSizes[currHaloDim * 2 + 1]));
        // the halo shape of lower dims excludes the halos
        dimSizes[i] =
            arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
                .getResult();
      } else {
        dimSizes[i] = shape[i];
      }
    }

    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
    auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
    auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);

    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
                                       rewriter.getIndexType());
    auto myMultiIndex =
        ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
            .getResult();
    // traverse all split axes from high to low dim
    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
      auto splitAxes = opSplitAxes[dim];
      if (splitAxes.empty())
        continue;
      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
      // Get the linearized ids of the neighbors (down and up) for the
      // given split
      auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
                                                  myMultiIndex, splitAxes)
                     .getResults();
      // MPI operates on i32...
      Value neighbourIDs[2] = {
          arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
                                     tmp[0]),
          arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
                                     tmp[1])};

      auto lowerRecvOffset = rewriter.getIndexAttr(0);
      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
      auto upperRecvOffset =
          arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
                                toValue(haloSizes[currHaloDim * 2 + 1]));
      auto upperSendOffset = arith::SubIOp::create(
          rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));

      Value commWorld = mpi::CommWorldOp::create(
          rewriter, loc, mpi::CommType::get(op->getContext()));

      // Make sure we send/recv in a way that does not lead to a dead-lock.
      // The current approach is by far not optimal; it should at least use
      // a red-black pattern or MPI_Sendrecv.
      // Also, buffers should be re-used.
      // Still using temporary contiguous buffers for MPI communication...
      // Still yielding a "serialized" communication pattern...
      auto genSendRecv = [&](bool upperHalo) {
        auto orgOffset = offsets[dim];
        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                  : haloSizes[currHaloDim * 2];
        // Check if we need to send and/or receive
        // Processes on the grid borders have only one neighbor
        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
        auto hasFrom = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::sge, from, zero);
        auto hasTo = arith::CmpIOp::create(rewriter, loc,
                                           arith::CmpIPredicate::sge, to, zero);
        auto buffer = memref::AllocOp::create(
            rewriter, loc, dimSizes,
            cast<ShapedType>(array.getType()).getElementType());
        // if has neighbor: copy halo data from array to buffer and send
        scf::IfOp::create(
            rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
                                       : OpFoldResult(upperSendOffset);
              auto subview = memref::SubViewOp::create(
                  builder, loc, array, offsets, dimSizes, strides);
              memref::CopyOp::create(builder, loc, subview, buffer);
              mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
                                  commWorld);
              scf::YieldOp::create(builder, loc);
            });
        // if has neighbor: receive halo data into buffer and copy to array
        scf::IfOp::create(
            rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                       : OpFoldResult(lowerRecvOffset);
              mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
                                  commWorld);
              auto subview = memref::SubViewOp::create(
                  builder, loc, array, offsets, dimSizes, strides);
              memref::CopyOp::create(builder, loc, buffer, subview);
              scf::YieldOp::create(builder, loc);
            });
        memref::DeallocOp::create(rewriter, loc, buffer);
        offsets[dim] = orgOffset;
      };

      auto doSendRecv = [&](int upOrDown) {
        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
        Value haloSz = dyn_cast<Value>(v);
        if (!haloSz)
          haloSz = arith::ConstantOp::create(
              rewriter, loc,
              rewriter.getI32IntegerAttr(
                  cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
        auto hasSize = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
        scf::IfOp::create(rewriter, loc, hasSize,
                          [&](OpBuilder &builder, Location loc) {
                            genSendRecv(upOrDown > 0);
                            scf::YieldOp::create(builder, loc);
                          });
      };

      doSendRecv(0);
      doSendRecv(1);

      // the shape for lower dims includes higher dims' halos
      dimSizes[dim] = shape[dim];
      // -> the offset for higher dims is always 0
      offsets[dim] = rewriter.getIndexAttr(0);
      // on to next halo
      --currHaloDim;
    }

    if (isa<MemRefType>(op.getResult().getType())) {
      rewriter.replaceOp(op, array);
    } else {
      assert(isa<RankedTensorType>(op.getResult().getType()));
      rewriter.replaceOp(op, bufferization::ToTensorOp::create(
                                 rewriter, loc, op.getResult().getType(), array,
                                 /*restrict=*/true, /*writable=*/true));
    }
    return success();
  }
};

struct ConvertShardToMPIPass
    : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
  using Base::Base;

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    auto *ctxt = &getContext();
    RewritePatternSet patterns(ctxt);
    ConversionTarget target(getContext());

    // Define a type converter to convert shard::ShardingType,
    // mostly for use in return operations.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });

    // convert shard::ShardingType to a tuple of RankedTensorTypes
    typeConverter.addConversion(
        [](ShardingType type,
           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          auto i16 = IntegerType::get(type.getContext(), 16);
          auto i64 = IntegerType::get(type.getContext(), 64);
          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
                                        ShapedType::kDynamic};
          results.emplace_back(RankedTensorType::get(shp, i16));
          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
          results.emplace_back(RankedTensorType::get(shp, i64));
          return success();
        });

    // To 'extract' components, an UnrealizedConversionCastOp is expected
    // to define the input.
    typeConverter.addTargetMaterialization(
        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
            Location loc) {
          // Expecting a single input.
          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
            return SmallVector<Value>();
          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
          // Expecting an UnrealizedConversionCastOp.
          if (!castOp)
            return SmallVector<Value>();
          // Fill a vector with elements of the tuple/castOp.
          SmallVector<Value> results;
          for (auto oprnd : castOp.getInputs()) {
            if (!isa<RankedTensorType>(oprnd.getType()))
              return SmallVector<Value>();
            results.emplace_back(oprnd);
          }
          return results;
        });

    // No shard dialect ops should be left after conversion...
    target.addIllegalDialect<shard::ShardDialect>();
    // ...except the global GridOp and GridShapeOp, which will get folded later.
    target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
    // Allow all dialects that the patterns may produce.
    target.addLegalDialect<
        BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
        tensor::TensorDialect, bufferization::BufferizationDialect,
        linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
    // Make sure the function signature, calls etc. are legal
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
                 ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
                 ConvertProcessLinearIndexOp>(typeConverter, ctxt);

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    populateCallOpTypeConversionPattern(patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);

    (void)applyPartialConversion(getOperation(), target, std::move(patterns));

    // Folding patterns cannot be mixed with conversion patterns -> extra pass.
    patterns.clear();
    SymbolTableCollection symbolTableCollection;
    mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace