ShardToMPI.cpp
1//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation of Shard communication ops to MPI ops.
10//
11//===----------------------------------------------------------------------===//
12
15
33#include "mlir/IR/Builders.h"
37#include "mlir/IR/SymbolTable.h"
40
41#define DEBUG_TYPE "shard-to-mpi"
42
43namespace mlir {
44#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
45#include "mlir/Conversion/Passes.h.inc"
46} // namespace mlir
47
48using namespace mlir;
49using namespace shard;
50
51namespace {
52/// Converts a vector of OpFoldResults (ints) into vector of Values of the
53/// provided type.
54static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
55 llvm::ArrayRef<int64_t> statics,
56 ValueRange dynamics,
57 Type type = Type()) {
58 SmallVector<Value> values;
59 auto dyn = dynamics.begin();
60 Type i64 = b.getI64Type();
61 if (!type)
62 type = i64;
63 assert((i64 == type || b.getIndexType() == type) &&
64 "expected an i64 or an intex type");
65 for (auto s : statics) {
66 if (s == ShapedType::kDynamic) {
67 values.emplace_back(*(dyn++));
68 } else {
69 TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
70 values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
71 }
72 }
73 return values;
74}
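
// Illustrative sketch, not part of this file: the same interleaving with
// plain integers. A static entry equal to the dynamic sentinel consumes the
// next value from the dynamic list; all names and numbers are hypothetical.
#include <cstdint>
#include <vector>

std::vector<int64_t> mixStaticAndDynamic(const std::vector<int64_t> &statics,
                                         const std::vector<int64_t> &dynamics,
                                         int64_t dynamicSentinel) {
  std::vector<int64_t> result;
  auto dyn = dynamics.begin();
  for (int64_t s : statics)
    result.push_back(s == dynamicSentinel ? *(dyn++) : s);
  return result; // e.g. statics {4, sentinel, 8} + dynamics {6} -> {4, 6, 8}
}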
75
76/// Create operations converting a linear index to a multi-dimensional index.
77[[maybe_unused]] static SmallVector<Value>
78linearToMultiIndex(Location loc, OpBuilder b, Value linearIndex,
79 ValueRange dimensions) {
80 int n = dimensions.size();
81 SmallVector<Value> multiIndex(n);
82
83 for (int i = n - 1; i >= 0; --i) {
84 multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
85 if (i > 0)
86 linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
87 }
88
89 return multiIndex;
90}
91
92/// Create operations converting a multi-dimensional index to a linear index.
93Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
94 ValueRange dimensions) {
95
96 Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0);
97 Value stride = arith::ConstantIndexOp::create(b, loc, 1);
98
99 for (int i = multiIndex.size() - 1; i >= 0; --i) {
100 Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
101 linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
102 stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
103 }
104
105 return linearIndex;
106}
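
// Illustrative sketch, not part of this file: the two index conversions above
// with plain integers instead of emitted arith ops, assuming a row-major
// layout. Names are hypothetical.
#include <cstdint>
#include <vector>

std::vector<int64_t> delinearize(int64_t linear,
                                 const std::vector<int64_t> &dims) {
  std::vector<int64_t> multi(dims.size());
  for (int i = (int)dims.size() - 1; i >= 0; --i) {
    multi[i] = linear % dims[i];
    linear /= dims[i];
  }
  return multi;
}

int64_t linearize(const std::vector<int64_t> &multi,
                  const std::vector<int64_t> &dims) {
  int64_t linear = 0, stride = 1;
  for (int i = (int)dims.size() - 1; i >= 0; --i) {
    linear += multi[i] * stride;
    stride *= dims[i];
  }
  return linear; // e.g. dims {2, 3}, multi {1, 2} -> 5
}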
107
108/// Replace GetShardingOp with related/dependent ShardingOp.
109struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
110 using OpConversionPattern::OpConversionPattern;
111
112 LogicalResult
113 matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
114 ConversionPatternRewriter &rewriter) const override {
115 auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
116 if (!shardOp)
117 return failure();
118 auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
119 if (!shardingOp)
120 return failure();
121
122 rewriter.replaceOp(op, shardingOp.getResult());
123 return success();
124 }
125};
126
127/// Convert a sharding op to a tuple of tensors of its components
128/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
129/// as defined by type converter.
130struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
131 using OpConversionPattern::OpConversionPattern;
132
133 LogicalResult
134 matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
135 ConversionPatternRewriter &rewriter) const override {
136 auto splitAxes = op.getSplitAxes().getAxes();
137 int64_t maxNAxes = 0;
138 for (auto axes : splitAxes)
139 maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
140
141 // To hold the split axes, create empty 2d tensor with shape
142 // {splitAxes.size(), max-size-of-split-groups}.
143 // Set trailing elements for smaller split-groups to -1.
144 Location loc = op.getLoc();
145 auto i16 = rewriter.getI16Type();
146 auto i64 = rewriter.getI64Type();
147 std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
148 maxNAxes};
149 Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
150 auto attr = IntegerAttr::get(i16, -1);
151 Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
152 resSplitAxes =
153 linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
154 .getResult(0);
155
156 // explicitly write values into tensor row by row
157 std::array<int64_t, 2> strides = {1, 1};
158 int64_t nSplits = 0;
159 ValueRange empty = {};
160 for (auto [i, axes] : llvm::enumerate(splitAxes)) {
161 int64_t size = axes.size();
162 if (size > 0)
163 ++nSplits;
164 std::array<int64_t, 2> offs = {(int64_t)i, 0};
165 std::array<int64_t, 2> sizes = {1, size};
166 auto tensorType = RankedTensorType::get({size}, i16);
167 auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
168 auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
169 resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
170 resSplitAxes, empty, empty,
171 empty, offs, sizes, strides);
172 }
173
174 // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}.
175 // Store the halo sizes in the tensor.
176 SmallVector<Value> haloSizes =
177 getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
178 adaptor.getDynamicHaloSizes());
179 auto type = RankedTensorType::get({nSplits, 2}, i64);
180 Value resHaloSizes =
181 haloSizes.empty()
182 ? tensor::EmptyOp::create(rewriter, loc,
183 std::array<int64_t, 2>{0, 0}, i64)
184 .getResult()
185 : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
186 .getResult();
187
188 // To hold sharded dims offsets, create Tensor with shape {nSplits,
189 // maxSplitSize+1}. Store the offsets in the tensor but set trailing
190 // elements for smaller split-groups to -1. Computing the max size of the
191 // split groups requires collectiveProcessGroupSize (which needs the
192 // GridOp).
193 Value resOffsets;
194 if (adaptor.getStaticShardedDimsOffsets().empty()) {
195 resOffsets = tensor::EmptyOp::create(rewriter, loc,
196 std::array<int64_t, 2>{0, 0}, i64);
197 } else {
198 SymbolTableCollection symbolTableCollection;
199 auto gridOp = getGrid(op, symbolTableCollection);
200 int64_t maxSplitSize = 0;
201 for (auto axes : splitAxes) {
202 int64_t splitSize =
203 collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
204 assert(splitSize != ShapedType::kDynamic);
205 maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
206 }
207 assert(maxSplitSize);
208 ++maxSplitSize; // add one for the total size
209
210 resOffsets = tensor::EmptyOp::create(
211 rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
212 Value zero = arith::ConstantOp::create(
213 rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
214 resOffsets =
215 linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0);
216 SmallVector<Value> offsets =
217 getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
218 adaptor.getDynamicShardedDimsOffsets());
219 int64_t curr = 0;
220 for (auto [i, axes] : llvm::enumerate(splitAxes)) {
221 int64_t splitSize =
222 collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
223 assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
224 ++splitSize; // add one for the total size
225 ArrayRef<Value> values(&offsets[curr], splitSize);
226 Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
227 std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
228 std::array<int64_t, 2> sizes = {1, splitSize};
229 resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
230 resOffsets, empty, empty,
231 empty, offs, sizes, strides);
232 curr += splitSize;
233 }
234 }
235
236 // return a tuple of tensors as defined by type converter
237 SmallVector<Type> resTypes;
238 if (failed(getTypeConverter()->convertType(op.getResult().getType(),
239 resTypes)))
240 return failure();
241
242 resSplitAxes =
243 tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
244 resHaloSizes =
245 tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
246 resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);
247
248 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
249 op, TupleType::get(op.getContext(), resTypes),
250 ValueRange{resSplitAxes, resHaloSizes, resOffsets});
251
252 return success();
253 }
254};
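
// Illustrative sketch, not part of this file: the row-padded encoding of the
// split axes that the pattern materializes as an i16 tensor. Shorter rows are
// padded with -1; names and values are hypothetical.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<std::vector<int16_t>>
padSplitAxes(const std::vector<std::vector<int16_t>> &splitAxes) {
  size_t maxNAxes = 0;
  for (const auto &axes : splitAxes)
    maxNAxes = std::max(maxNAxes, axes.size());
  std::vector<std::vector<int16_t>> rows;
  for (const auto &axes : splitAxes) {
    std::vector<int16_t> row(maxNAxes, -1); // trailing entries stay -1
    std::copy(axes.begin(), axes.end(), row.begin());
    rows.push_back(row);
  }
  return rows; // e.g. {{0}, {1, 2}} -> {{0, -1}, {1, 2}}
}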
255
256class ConvertProcessLinearIndexOp
257 : public OpConversionPattern<ProcessLinearIndexOp> {
258
259public:
260 using OpConversionPattern::OpConversionPattern;
261
262 LogicalResult
263 matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
264 ConversionPatternRewriter &rewriter) const override {
265 // Create mpi::CommRankOp
266 Location loc = op.getLoc();
267 auto *ctx = op.getContext();
268 Value commWorld =
269 mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
270 auto rank = mpi::CommRankOp::create(
271 rewriter, loc,
272 TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
273 commWorld)
274 .getRank();
275 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
276 rank);
277 return success();
278 }
279};
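
// Illustrative sketch, not part of this file: the runtime call that the
// emitted mpi.comm_rank on the world communicator ultimately corresponds to
// once the MPI dialect is lowered further.
#include <mpi.h>

int processLinearIndex() {
  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank); // linear index of this process
  return rank;
}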
280
281struct ConvertNeighborsLinearIndicesOp
282 : public OpConversionPattern<NeighborsLinearIndicesOp> {
283 using OpConversionPattern::OpConversionPattern;
284
285 LogicalResult
286 matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter) const override {
288
289 // Computes the neighbor indices along a split axis by simply
290 // adding/subtracting 1 to/from the current index in that dimension.
291 // Assigns -1 if the neighbor is out of bounds.
292
293 auto axes = adaptor.getSplitAxes();
294 // For now only single axis sharding is supported
295 if (axes.size() != 1)
296 return failure();
297
298 Location loc = op.getLoc();
299 SymbolTableCollection symbolTableCollection;
300 auto gridOp = getGrid(op, symbolTableCollection);
301 auto mIdx = adaptor.getDevice();
302 auto orgIdx = mIdx[axes[0]];
303 SmallVector<Value> dims;
304 llvm::transform(
305 gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
306 return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
307 });
308 Value dimSz = dims[axes[0]];
309 Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
310 Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1);
311 Value atBorder =
312 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
313 arith::ConstantIndexOp::create(rewriter, loc, 0));
314 auto down = scf::IfOp::create(
315 rewriter, loc, atBorder,
316 [&](OpBuilder &builder, Location loc) {
317 scf::YieldOp::create(builder, loc, minus1);
318 },
319 [&](OpBuilder &builder, Location loc) {
320 SmallVector<Value> tmp = mIdx;
321 tmp[axes[0]] =
322 arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
323 .getResult();
324 scf::YieldOp::create(builder, loc,
325 multiToLinearIndex(loc, rewriter, tmp, dims));
326 });
327 atBorder = arith::CmpIOp::create(
328 rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
329 arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
330 auto up = scf::IfOp::create(
331 rewriter, loc, atBorder,
332 [&](OpBuilder &builder, Location loc) {
333 scf::YieldOp::create(builder, loc, minus1);
334 },
335 [&](OpBuilder &builder, Location loc) {
336 SmallVector<Value> tmp = mIdx;
337 tmp[axes[0]] =
338 arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
339 scf::YieldOp::create(builder, loc,
340 multiToLinearIndex(loc, rewriter, tmp, dims));
341 });
342 rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
343 return success();
344 }
345};
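
// Illustrative sketch, not part of this file: the neighbor computation with
// plain integers for a single split axis, mirroring the scf.if chains built
// above. Out-of-bounds neighbors are reported as -1; names are hypothetical.
#include <cstdint>
#include <utility>
#include <vector>

std::pair<int64_t, int64_t>
neighborsAlongAxis(std::vector<int64_t> myIdx,
                   const std::vector<int64_t> &dims, size_t axis) {
  auto toLinear = [&](const std::vector<int64_t> &idx) {
    int64_t linear = 0;
    for (size_t i = 0; i < dims.size(); ++i)
      linear = linear * dims[i] + idx[i]; // row-major, as multiToLinearIndex
    return linear;
  };
  int64_t down = -1, up = -1;
  if (myIdx[axis] > 0) {
    auto tmp = myIdx;
    --tmp[axis];
    down = toLinear(tmp);
  }
  if (myIdx[axis] < dims[axis] - 1) {
    auto tmp = myIdx;
    ++tmp[axis];
    up = toLinear(tmp);
  }
  return {down, up};
}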
346
347struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
348 using OpConversionPattern::OpConversionPattern;
349
350 LogicalResult
351 matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
352 ConversionPatternRewriter &rewriter) const override {
353 auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
354 if (!sharding) {
355 return op->emitError()
356 << "Expected ShardingOp as defining op for sharding"
357 << " but found " << adaptor.getSharding()[0].getDefiningOp();
358 }
359
360 // Compute the sharded shape by applying the sharding to the input shape.
361 // If shardedDimsOffsets is not defined in the sharding, the shard shape is
362 // computed by dividing the dimension size by the number of shards in that
363 // dimension (which is given by the size of the grid axes provided in
364 // split-axes). Leftover elements get distributed to the trailing shards. If a
365 // shardedDimsOffsets is provided, the shard shape is computed by
366 // subtracting the offset of the current shard from the offset of the next
367 // shard.
368
369 Location loc = op.getLoc();
370 Type index = rewriter.getIndexType();
371
372 // This is a 1:N conversion because the sharding op is a 1:3 conversion.
373 // The operands in the adaptor are a vector<ValueRange>. For dims and device
374 // we have a 1:1 conversion.
375 // For simpler access fill a vector with the dynamic dims.
376 SmallVector<Value> dynDims, dynDevice;
377 for (auto dim : adaptor.getDimsDynamic()) {
378 // type conversion should be 1:1 for ints
379 dynDims.emplace_back(llvm::getSingleElement(dim));
380 }
381 // same for device
382 for (auto device : adaptor.getDeviceDynamic()) {
383 dynDevice.emplace_back(llvm::getSingleElement(device));
384 }
385
386 // To keep the code simple, convert dims/device to values when they are
387 // attributes. Count on canonicalization to fold static values.
388 SmallVector<Value> shape =
389 getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
390 SmallVector<Value> multiIdx =
391 getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
392
393 // Get the GridOp, the grid shape is needed to compute the sharded shape.
394 SymbolTableCollection symbolTableCollection;
395 auto gridOp = getGrid(sharding, symbolTableCollection);
396 // For now we only support static grid shapes
397 if (ShapedType::isDynamicShape(gridOp.getShape()))
398 return failure();
399
400 auto splitAxes = sharding.getSplitAxes().getAxes();
401 // shardedDimsOffsets are optional and might be Values (not attributes).
402 // Also, the shardId might be dynamic which means the position in the
403 // shardedDimsOffsets is not statically known. Create a tensor of the
404 // shardedDimsOffsets and later extract the offsets for computing the
405 // local shard-size.
406 Value shardedDimsOffs;
407 {
408 SmallVector<Value> tmp = getMixedAsValues(
409 rewriter, loc, sharding.getStaticShardedDimsOffsets(),
410 sharding.getDynamicShardedDimsOffsets(), index);
411 if (!tmp.empty())
412 shardedDimsOffs = tensor::FromElementsOp::create(
413 rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
414 tmp);
415 }
416
417 // With static grid shape the sizes of the split axes are known.
418 // Hence the start/pos for each split axes in shardDimsOffsets can be
419 // computed statically.
420 int64_t pos = 0;
421 SmallVector<Value> shardShape;
422 Value zero =
423 arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
424 Value one =
425 arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));
426
427 // Iterate over the dimensions of the tensor shape, get their split Axes,
428 // and compute the sharded shape.
429 for (auto [i, dim] : llvm::enumerate(shape)) {
430 // Trailing dimensions might not be annotated.
431 if (i < splitAxes.size() && !splitAxes[i].empty()) {
432 auto axes = splitAxes[i];
433 // The current dimension might not be sharded.
434 // Create a value from the static position in shardDimsOffsets.
435 Value posVal = arith::ConstantOp::create(rewriter, loc,
436 rewriter.getIndexAttr(pos));
437 // Get the index of the local shard in the grid axis.
438 Value idx = multiIdx[axes[0]];
439 auto numShards =
440 collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
441 if (shardedDimsOffs) {
442 // If sharded dims offsets are provided, use them to compute the
443 // sharded shape.
444 if (axes.size() > 1) {
445 return op->emitError() << "Only single axis sharding is "
446 << "supported for each dimension.";
447 }
448 idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
449 // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
450 Value off =
451 tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
452 idx = arith::AddIOp::create(rewriter, loc, idx, one);
453 Value nextOff =
454 tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
455 Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
456 shardShape.emplace_back(sz);
457 } else {
458 Value numShardsVal = arith::ConstantOp::create(
459 rewriter, loc, rewriter.getIndexAttr(numShards));
460 // Compute shard dim size by distributing leftover elements to trailing
461 // shards:
462 // sz = dim / numShards
463 // + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
464 Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
465 Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
466 sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
467 auto cond = arith::CmpIOp::create(
468 rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
469 Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
470 sz = arith::AddIOp::create(rewriter, loc, sz, odd);
471 shardShape.emplace_back(sz);
472 }
473 pos += numShards + 1; // add one for the total size.
474 } // else no sharding if split axis is empty or no split axis
475 // If no size was added -> no sharding in this dimension.
476 if (shardShape.size() <= i)
477 shardShape.emplace_back(dim);
478 }
479 assert(shardShape.size() == shape.size());
480 rewriter.replaceOp(op, shardShape);
481 return success();
482 }
483};
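
// Illustrative sketch, not part of this file: the default shard-size rule
// with plain integers when no shardedDimsOffsets are given. Leftover elements
// go to the trailing shards, matching the arithmetic emitted above.
#include <cstdint>

int64_t localShardSize(int64_t dimSize, int64_t numShards, int64_t shardIdx) {
  int64_t base = dimSize / numShards;
  int64_t rem = dimSize % numShards;
  // The last `rem` shards get one extra element,
  // e.g. dimSize=10, numShards=4 -> sizes {2, 2, 3, 3}.
  return base + (shardIdx >= numShards - rem ? 1 : 0);
}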
484
485static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
486 auto *ctx = kind.getContext();
487 auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
488 return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
489 };
490
491 switch (kind.getValue()) {
492 case ReductionKind::Sum:
493 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
494 case ReductionKind::Product:
495 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
496 case ReductionKind::Min:
497 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
498 case ReductionKind::Max:
499 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
500 case ReductionKind::BitwiseAnd:
501 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
502 case ReductionKind::BitwiseOr:
503 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
504 case ReductionKind::BitwiseXor:
505 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
506 default:
507 llvm_unreachable("Unknown/unsupported reduction kind");
508 }
509}
510
511template <typename CommOp>
512struct CommOpPattern : public OpConversionPattern<CommOp> {
513 using OpConversionPattern<CommOp>::OpConversionPattern;
514
515 MemRefType getMemrefType(ShapedType tensorType) const {
516 return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
517 }
518
519 Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const {
520 auto itype = input.getType();
521 // If the input is a tensor, convert it to a memref.
522 if (isa<RankedTensorType>(itype)) {
523 auto memrefType = getMemrefType(cast<ShapedType>(itype));
524 input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
525 } else {
526 assert(isa<MemRefType>(itype) &&
527 "expected input to be of MemRefType or TensorType");
528 }
529 return input;
530 }
531
532 FailureOr<GridOp> checkGrid(CommOp op,
533 SymbolTableCollection &symbolTableCollection,
534 bool allowDynamic = false) const {
535 GridOp gridOp = getGrid(op, symbolTableCollection);
536 if (!gridOp)
537 return op->emitError() << "Missing grid symbol.";
538 if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
539 return op->emitError() << "Dynamic grid shape not supported.";
540 return gridOp;
541 }
542
543 // Get an MPI_Comm_split for a given grid and axes.
544 // The color is the linear index of the process in the grid along the
545 // non-'grid-axes'. The key is the linear index of the process in the grid
546 // along the grid-axes.
547 Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
548 ImplicitLocOpBuilder &iBuilder) const {
549 size_t gridDims = gridOp.getShape().size();
550 auto commType = mpi::CommType::get(gridOp->getContext());
551 Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
552
553 if (gridAxes.empty() || gridAxes.size() >= gridDims) {
554 return commWorld;
555 }
556
557 SmallVector<GridAxis> otherAxes;
558 for (GridAxis i = 0; i < static_cast<GridAxis>(gridDims); ++i) {
559 if (!llvm::is_contained(gridAxes, i))
560 otherAxes.emplace_back(i);
561 }
562
563 SmallVector<Type> indexResultTypes(otherAxes.size(),
564 iBuilder.getIndexType());
565
566 Value color =
567 createProcessLinearIndex(iBuilder, gridOp.getSymName(), otherAxes);
568 color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
569
570 Value key =
571 createProcessLinearIndex(iBuilder, gridOp.getSymName(), gridAxes);
572 key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
573
574 // Finally split the communicator
575 return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
576 .getNewcomm();
577 }
578};
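
// Illustrative sketch, not part of this file: the runtime splitting that
// getComm models. Processes with the same position along the non-participating
// grid axes share a color and end up in the same sub-communicator; the key
// orders them by their position along the participating axes. Names are
// hypothetical.
#include <mpi.h>

MPI_Comm splitForGridAxes(int color /* linear index along other axes */,
                          int key /* linear index along the grid axes */) {
  MPI_Comm newComm = MPI_COMM_NULL;
  MPI_Comm_split(MPI_COMM_WORLD, color, key, &newComm);
  return newComm;
}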
579
580struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
581 using CommOpPattern::CommOpPattern;
582
583 LogicalResult
584 matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
585 ConversionPatternRewriter &rewriter) const override {
586 SymbolTableCollection symbolTableCollection;
587 FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
588 if (failed(gridOp))
589 return failure();
590 ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
591 Value input = getAsMemref(adaptor.getInput(), iBuilder);
592 MemRefType inType = cast<MemRefType>(input.getType());
593 if (!isStaticShapeAndContiguousRowMajor(inType))
594 return op.emitError(
595 "Expected static shaped memref in contiguous row-major layout.");
596 MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
597 if (!isStaticShapeAndContiguousRowMajor(outType))
598 return op.emitError(
599 "Expected static shaped memref in contiguous row-major layout.");
600
601 // Allocate buffer and copy input to buffer.
602 Value buffer = memref::AllocOp::create(iBuilder, outType);
603 linalg::CopyOp::create(iBuilder, input, buffer);
604 // Get the right communicator
605 Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
606 // Create the MPI AllReduce operation.
607 mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer, buffer,
608 getMPIReductionOp(adaptor.getReductionAttr()),
609 comm);
610
611 // If the result type is a tensor, convert the buffer back to a tensor.
612 if (isa<RankedTensorType>(op.getType()))
613 buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
614 true);
615 rewriter.replaceOp(op, buffer);
616 return success();
617 }
618};
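
// Illustrative sketch, not part of this file: the runtime collective behind
// the lowering. The buffer is reduced in place across all ranks of the
// (possibly split) communicator; element type and count are hypothetical.
#include <mpi.h>

void allReduceSumInPlace(double *buffer, int count, MPI_Comm comm) {
  MPI_Allreduce(MPI_IN_PLACE, buffer, count, MPI_DOUBLE, MPI_SUM, comm);
}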
619
620struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
621 using CommOpPattern::CommOpPattern;
622
623 LogicalResult
624 matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
625 ConversionPatternRewriter &rewriter) const override {
626 SymbolTableCollection symbolTableCollection;
627 FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
628 if (failed(gridOp))
629 return failure();
630 ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
631 Value input = getAsMemref(adaptor.getInput(), iBuilder);
632 MemRefType inType = cast<MemRefType>(input.getType());
633 if (!isStaticShapeAndContiguousRowMajor(inType))
634 return op.emitError(
635 "Expected static shaped memref in contiguous row-major layout.");
636 MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
637 if (!isStaticShapeAndContiguousRowMajor(outType))
638 return op.emitError(
639 "Expected static shaped memref in contiguous row-major layout.");
640
641 // Get the right communicator
642 Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
643 // Allocate output buffer
644 Value output = memref::AllocOp::create(iBuilder, outType);
645 // Create the MPI AllGather operation.
646 mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);
647
648 // If the result type is a tensor, convert the buffer back to a tensor.
649 if (isa<RankedTensorType>(op.getType()))
650 output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
651 true);
652 rewriter.replaceOp(op, output);
653 return success();
654 }
655};
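
// Illustrative sketch, not part of this file: the runtime collective behind
// the lowering. Every rank contributes sendCount elements and receives the
// concatenation of all contributions; element type and counts are
// hypothetical.
#include <mpi.h>

void allGather(const double *localShard, int sendCount, double *gathered,
               MPI_Comm comm) {
  MPI_Allgather(localShard, sendCount, MPI_DOUBLE, gathered, sendCount,
                MPI_DOUBLE, comm);
}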
656
657struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
658 using OpConversionPattern::OpConversionPattern;
659
660 LogicalResult
661 matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
662 ConversionPatternRewriter &rewriter) const override {
663
664 // The input/output memref is assumed to be in C memory order.
665 // Halos are exchanged as 2 blocks per dimension (one for each side: down
666 // and up). For each haloed dimension `d`, the exchanged blocks are
667 // expressed as multi-dimensional subviews. The subviews include potential
668 // halos of higher dimensions `dh > d`, no halos for the lower dimensions
669 // `dl < d` and for dimension `d` the currently exchanged halo only.
670 // By iterating from higher to lower dimensions this also updates the halos
671 // in the 'corners'.
672 // memref.subview is used to read and write the halo data from and to the
673 // local data. Because subviews and halos can have mixed dynamic and static
674 // shapes, OpFoldResults are used whenever possible.
675
676 auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
677 adaptor.getHaloSizes(), rewriter);
678 if (haloSizes.empty()) {
679 // no halos -> nothing to do
680 rewriter.replaceOp(op, adaptor.getDestination());
681 return success();
682 }
683
684 SymbolTableCollection symbolTableCollection;
685 Location loc = op.getLoc();
686
687 // convert an OpFoldResult into a Value
688 auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
689 if (auto value = dyn_cast<Value>(v))
690 return value;
691 return arith::ConstantOp::create(
692 rewriter, loc,
693 rewriter.getIndexAttr(
694 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
695 };
696
697 auto dest = adaptor.getDestination();
698 auto dstShape = cast<ShapedType>(dest.getType()).getShape();
699 Value array = dest;
700 if (isa<RankedTensorType>(array.getType())) {
701 // If the destination is a tensor, convert it to a memref (buffer).
702 auto mmemrefType = MemRefType::get(
703 dstShape, cast<ShapedType>(array.getType()).getElementType());
704 array =
705 bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array);
706 }
707 auto rank = cast<ShapedType>(array.getType()).getRank();
708 auto opSplitAxes = adaptor.getSplitAxes().getAxes();
709 auto grid = adaptor.getGrid();
710 auto gridOp = getGrid(op, symbolTableCollection);
711 // subviews need Index values
712 for (auto &sz : haloSizes) {
713 if (auto value = dyn_cast<Value>(sz))
714 sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
715 value)
716 .getResult();
717 }
718
719 // most of the offset/size/stride data is the same for all dims
720 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
721 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
722 SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
723 auto currHaloDim = -1; // halo sizes are provided for split dimensions only
724 // we need the actual shape to compute offsets and sizes
725 for (auto i = 0; i < rank; ++i) {
726 auto s = dstShape[i];
727 if (ShapedType::isDynamic(s))
728 shape[i] = memref::DimOp::create(rewriter, loc, array, i).getResult();
729 else
730 shape[i] = rewriter.getIndexAttr(s);
731
732 if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
733 ++currHaloDim;
734 // the offsets for lower dims start after their down halo
735 offsets[i] = haloSizes[currHaloDim * 2];
736
737 // prepare shape and offsets of highest dim's halo exchange
738 Value _haloSz = arith::AddIOp::create(
739 rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
740 toValue(haloSizes[currHaloDim * 2 + 1]));
741 // the halo shape of lower dims excludes the halos
742 dimSizes[i] =
743 arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
744 .getResult();
745 } else {
746 dimSizes[i] = shape[i];
747 }
748 }
749
750 auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
751 auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
752 auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
753 auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
754
755 SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
756 rewriter.getIndexType());
757 auto myMultiIndex =
758 ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
759 .getResult();
760 // traverse all split axes from high to low dim
761 for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
762 auto splitAxes = opSplitAxes[dim];
763 if (splitAxes.empty())
764 continue;
765 assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
766 // Get the linearized ids of the neighbors (down and up) for the
767 // given split
768 auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
769 myMultiIndex, splitAxes)
770 .getResults();
771 // MPI operates on i32...
772 Value neighbourIDs[2] = {
773 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
774 tmp[0]),
775 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
776 tmp[1])};
777
778 auto lowerRecvOffset = rewriter.getIndexAttr(0);
779 auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
780 auto upperRecvOffset =
781 arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
782 toValue(haloSizes[currHaloDim * 2 + 1]));
783 auto upperSendOffset = arith::SubIOp::create(
784 rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
785
786 Value commWorld = mpi::CommWorldOp::create(
787 rewriter, loc, mpi::CommType::get(op->getContext()));
788
789 // Make sure we send/recv in a way that does not lead to a deadlock.
790 // The current approach is far from optimal; it should at least use
791 // a red-black pattern or MPI_Sendrecv.
792 // Also, buffers should be re-used.
793 // Still using temporary contiguous buffers for MPI communication...
794 // Still yielding a "serialized" communication pattern...
795 auto genSendRecv = [&](bool upperHalo) {
796 auto orgOffset = offsets[dim];
797 dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
798 : haloSizes[currHaloDim * 2];
799 // Check if we need to send and/or receive
800 // Processes on the grid borders have only one neighbor
801 auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
802 auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
803 auto hasFrom = arith::CmpIOp::create(
804 rewriter, loc, arith::CmpIPredicate::sge, from, zero);
805 auto hasTo = arith::CmpIOp::create(rewriter, loc,
806 arith::CmpIPredicate::sge, to, zero);
807 auto buffer = memref::AllocOp::create(
808 rewriter, loc, dimSizes,
809 cast<ShapedType>(array.getType()).getElementType());
810 // if has neighbor: copy halo data from array to buffer and send
811 scf::IfOp::create(
812 rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
813 offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
814 : OpFoldResult(upperSendOffset);
815 auto subview = memref::SubViewOp::create(
816 builder, loc, array, offsets, dimSizes, strides);
817 memref::CopyOp::create(builder, loc, subview, buffer);
818 mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
819 commWorld);
820 scf::YieldOp::create(builder, loc);
821 });
822 // if has neighbor: receive halo data into buffer and copy to array
823 scf::IfOp::create(
824 rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
825 offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
826 : OpFoldResult(lowerRecvOffset);
827 mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
828 commWorld);
829 auto subview = memref::SubViewOp::create(
830 builder, loc, array, offsets, dimSizes, strides);
831 memref::CopyOp::create(builder, loc, buffer, subview);
832 scf::YieldOp::create(builder, loc);
833 });
834 memref::DeallocOp::create(rewriter, loc, buffer);
835 offsets[dim] = orgOffset;
836 };
837
838 auto doSendRecv = [&](int upOrDown) {
839 OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
840 Value haloSz = dyn_cast<Value>(v);
841 if (!haloSz)
842 haloSz = arith::ConstantOp::create(
843 rewriter, loc,
844 rewriter.getI32IntegerAttr(
845 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
846 auto hasSize = arith::CmpIOp::create(
847 rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
848 scf::IfOp::create(rewriter, loc, hasSize,
849 [&](OpBuilder &builder, Location loc) {
850 genSendRecv(upOrDown > 0);
851 scf::YieldOp::create(builder, loc);
852 });
853 };
854
855 doSendRecv(0);
856 doSendRecv(1);
857
858 // the shape for lower dims includes higher dims' halos
859 dimSizes[dim] = shape[dim];
860 // -> the offset for higher dims is always 0
861 offsets[dim] = rewriter.getIndexAttr(0);
862 // on to next halo
863 --currHaloDim;
864 }
865
866 if (isa<MemRefType>(op.getResult().getType())) {
867 rewriter.replaceOp(op, array);
868 } else {
869 assert(isa<RankedTensorType>(op.getResult().getType()));
870 rewriter.replaceOp(op, bufferization::ToTensorOp::create(
871 rewriter, loc, op.getResult().getType(), array,
872 /*restrict=*/true, /*writable=*/true));
873 }
874 return success();
875 }
876};
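
// Illustrative sketch, not part of this file: a 1-D version of the halo
// exchange generated above, with the same "skip if no neighbor / skip if the
// halo is empty" guards. The layout [lo halo | interior | hi halo], the tag
// and the use of plain blocking calls are simplifications; the generated IR
// additionally stages data through temporary contiguous buffers.
#include <mpi.h>

void exchangeHalos1D(double *data, int n, int lo, int hi, int downRank,
                     int upRank, MPI_Comm comm) {
  const int tag = 91;
  // Fill the upper halo [n - hi, n): send my lowest interior cells down,
  // receive from the upper neighbor.
  if (downRank >= 0 && hi > 0)
    MPI_Send(data + lo, hi, MPI_DOUBLE, downRank, tag, comm);
  if (upRank >= 0 && hi > 0)
    MPI_Recv(data + n - hi, hi, MPI_DOUBLE, upRank, tag, comm,
             MPI_STATUS_IGNORE);
  // Fill the lower halo [0, lo): send my highest interior cells up,
  // receive from the lower neighbor.
  if (upRank >= 0 && lo > 0)
    MPI_Send(data + n - hi - lo, lo, MPI_DOUBLE, upRank, tag, comm);
  if (downRank >= 0 && lo > 0)
    MPI_Recv(data, lo, MPI_DOUBLE, downRank, tag, comm, MPI_STATUS_IGNORE);
}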
877
878struct ConvertShardToMPIPass
879 : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
880 using Base::Base;
881
882 /// Run the dialect converter on the module.
883 void runOnOperation() override {
884 auto *ctxt = &getContext();
885 RewritePatternSet patterns(ctxt);
886 ConversionTarget target(getContext());
887
888 // Define a type converter to convert shard::ShardingType,
889 // mostly for use in return operations.
890 TypeConverter typeConverter;
891 typeConverter.addConversion([](Type type) { return type; });
892
893 // convert shard::ShardingType to a tuple of RankedTensorTypes
894 typeConverter.addConversion(
895 [](ShardingType type,
896 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
897 auto i16 = IntegerType::get(type.getContext(), 16);
898 auto i64 = IntegerType::get(type.getContext(), 64);
899 std::array<int64_t, 2> shp = {ShapedType::kDynamic,
900 ShapedType::kDynamic};
901 results.emplace_back(RankedTensorType::get(shp, i16));
902 results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
903 results.emplace_back(RankedTensorType::get(shp, i64));
904 return success();
905 });
906
907 // To 'extract' components, an UnrealizedConversionCastOp is expected
908 // to define the input.
909 typeConverter.addTargetMaterialization(
910 [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
911 Location loc) {
912 // Expecting a single input.
913 if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
914 return SmallVector<Value>();
915 auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
916 // Expecting an UnrealizedConversionCastOp.
917 if (!castOp)
918 return SmallVector<Value>();
919 // Fill a vector with elements of the tuple/castOp.
920 SmallVector<Value> results;
921 for (auto oprnd : castOp.getInputs()) {
922 if (!isa<RankedTensorType>(oprnd.getType()))
923 return SmallVector<Value>();
924 results.emplace_back(oprnd);
925 }
926 return results;
927 });
928
929 // No shard dialect ops should be left after conversion...
930 target.addIllegalDialect<shard::ShardDialect>();
931 // ...except the global GridOp and GridShapeOp; the latter gets folded later.
932 target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
933 // Allow all the stuff that our patterns will convert to
934 target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
935 arith::ArithDialect, tensor::TensorDialect,
936 bufferization::BufferizationDialect,
937 linalg::LinalgDialect, memref::MemRefDialect,
938 affine::AffineDialect, cf::ControlFlowDialect>();
939 // Make sure the function signature, calls etc. are legal
940 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
941 return typeConverter.isSignatureLegal(op.getFunctionType());
942 });
943 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
944 [&](Operation *op) { return typeConverter.isLegal(op); });
945
946 patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
947 ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
948 ConvertAllGatherOp, ConvertAllReduceOp,
949 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
950 SymbolTableCollection stc;
951 populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
952 populateAllSliceOpLoweringPatterns(patterns, stc);
953
954 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
955 patterns, typeConverter);
956 populateCallOpTypeConversionPattern(patterns, typeConverter);
957 populateReturnOpTypeConversionPattern(patterns, typeConverter);
958
959 (void)applyPartialConversion(getOperation(), target, std::move(patterns));
960
961 // Folding patterns cannot be mixed with conversion patterns -> extra pass.
962 patterns.clear();
963 SymbolTableCollection symbolTableCollection;
964 mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
965 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
966 }
967};
968
969} // namespace
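
// Illustrative sketch, not part of this file: running the conversion from a
// C++ pass pipeline. The constructor name below is assumed to follow the
// usual tablegen-generated create<PassName>() convention and may differ, as
// may the header that declares it.
#include "mlir/Conversion/Passes.h" // assumed to declare the generated factory
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

void runShardToMPI(mlir::ModuleOp module, mlir::MLIRContext &context) {
  mlir::PassManager pm(&context);
  // Assumed generated factory; see the GEN_PASS_DEF include above.
  pm.addPass(mlir::createConvertShardToMPIPass());
  (void)pm.run(module);
}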