MLIR 22.0.0git
ShardToMPI.cpp
1//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation of Shard communication ops to MPI ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18#include "mlir/Dialect/DLTI/DLTI.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/Linalg/IR/Linalg.h"
21#include "mlir/Dialect/MPI/IR/MPI.h"
22#include "mlir/Dialect/MemRef/IR/MemRef.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/Dialect/Shard/IR/ShardDialect.h"
25#include "mlir/Dialect/Shard/IR/ShardOps.h"
26#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
27#include "mlir/Dialect/Shard/Transforms/Transforms.h"
28#include "mlir/Dialect/Tensor/IR/Tensor.h"
29#include "mlir/Dialect/Utils/StaticValueUtils.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/BuiltinAttributes.h"
32#include "mlir/IR/BuiltinTypes.h"
33#include "mlir/IR/PatternMatch.h"
34#include "mlir/IR/SymbolTable.h"
35#include "mlir/Transforms/DialectConversion.h"
36#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
37
38#define DEBUG_TYPE "shard-to-mpi"
39
40namespace mlir {
41#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
42#include "mlir/Conversion/Passes.h.inc"
43} // namespace mlir
44
45using namespace mlir;
46using namespace shard;
47
48namespace {
49/// Converts a vector of OpFoldResults (ints) into vector of Values of the
50/// provided type.
51static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
52                                           llvm::ArrayRef<int64_t> statics,
53                                           ValueRange dynamics,
54                                           Type type = Type()) {
55 SmallVector<Value> values;
56 auto dyn = dynamics.begin();
57 Type i64 = b.getI64Type();
58 if (!type)
59 type = i64;
60 assert((i64 == type || b.getIndexType() == type) &&
61 "expected an i64 or an intex type");
62 for (auto s : statics) {
63 if (s == ShapedType::kDynamic) {
64 values.emplace_back(*(dyn++));
65 } else {
66 TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
67 values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
68 }
69 }
70 return values;
71}
72
73/// Create operations converting a linear index to a multi-dimensional index.
74static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
75 Value linearIndex,
76 ValueRange dimensions) {
77 int n = dimensions.size();
78 SmallVector<Value> multiIndex(n);
79
80 for (int i = n - 1; i >= 0; --i) {
81 multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
82 if (i > 0)
83 linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
84 }
85
86 return multiIndex;
87}
88
89/// Create operations converting a multi-dimensional index to a linear index.
90Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
91 ValueRange dimensions) {
92
93 Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0);
94 Value stride = arith::ConstantIndexOp::create(b, loc, 1);
95
96 for (int i = multiIndex.size() - 1; i >= 0; --i) {
97 Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
98 linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
99 stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
100 }
101
102 return linearIndex;
103}
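// Illustration (not part of this file): a minimal standalone sketch of the
// row-major index arithmetic that the two helpers above materialize as arith
// ops; plain int64_t values stand in for SSA Values.
#include <cstdint>
#include <vector>

static std::vector<int64_t> linearToMulti(int64_t linear,
                                          const std::vector<int64_t> &dims) {
  std::vector<int64_t> multi(dims.size());
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    multi[i] = linear % dims[i]; // remainder selects the index in dim i
    linear /= dims[i];           // strip that dimension off
  }
  return multi;
}

static int64_t multiToLinear(const std::vector<int64_t> &multi,
                             const std::vector<int64_t> &dims) {
  int64_t linear = 0, stride = 1;
  for (int i = static_cast<int>(multi.size()) - 1; i >= 0; --i) {
    linear += multi[i] * stride; // accumulate the offset of dim i
    stride *= dims[i];           // stride grows towards leading dims
  }
  return linear;
}
// E.g. for dims = {2, 3, 4}, linear index 13 maps to {1, 0, 1} and back.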
104
105/// Replace GetShardingOp with related/dependent ShardingOp.
106struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
107 using OpConversionPattern::OpConversionPattern;
108
109 LogicalResult
110 matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter) const override {
112 auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
113 if (!shardOp)
114 return failure();
115 auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
116 if (!shardingOp)
117 return failure();
118
119 rewriter.replaceOp(op, shardingOp.getResult());
120 return success();
121 }
122};
123
124/// Convert a sharding op to a tuple of tensors of its components
125/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
126/// as defined by type converter.
127struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
128 using OpConversionPattern::OpConversionPattern;
129
130 LogicalResult
131 matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter) const override {
133 auto splitAxes = op.getSplitAxes().getAxes();
134 int64_t maxNAxes = 0;
135 for (auto axes : splitAxes)
136 maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
137
138 // To hold the split axes, create empty 2d tensor with shape
139 // {splitAxes.size(), max-size-of-split-groups}.
140 // Set trailing elements for smaller split-groups to -1.
141 Location loc = op.getLoc();
142 auto i16 = rewriter.getI16Type();
143 auto i64 = rewriter.getI64Type();
144 std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
145 maxNAxes};
146 Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
147 auto attr = IntegerAttr::get(i16, -1);
148 Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
149 resSplitAxes =
150 linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
151 .getResult(0);
152
153 // explicitly write values into tensor row by row
154 std::array<int64_t, 2> strides = {1, 1};
155 int64_t nSplits = 0;
156 ValueRange empty = {};
157 for (auto [i, axes] : llvm::enumerate(splitAxes)) {
158 int64_t size = axes.size();
159 if (size > 0)
160 ++nSplits;
161 std::array<int64_t, 2> offs = {(int64_t)i, 0};
162 std::array<int64_t, 2> sizes = {1, size};
163 auto tensorType = RankedTensorType::get({size}, i16);
164 auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
165 auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
166 resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
167 resSplitAxes, empty, empty,
168 empty, offs, sizes, strides);
169 }
170
171 // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
172 // Store the halo sizes in the tensor.
173 SmallVector<Value> haloSizes =
174 getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
175 adaptor.getDynamicHaloSizes());
176 auto type = RankedTensorType::get({nSplits, 2}, i64);
177 Value resHaloSizes =
178 haloSizes.empty()
179 ? tensor::EmptyOp::create(rewriter, loc,
180 std::array<int64_t, 2>{0, 0}, i64)
181 .getResult()
182 : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
183 .getResult();
184
185 // To hold sharded dims offsets, create Tensor with shape {nSplits,
186 // maxSplitSize+1}. Store the offsets in the tensor but set trailing
187 // elements for smaller split-groups to -1. Computing the max size of the
188 // split groups requires collectiveProcessGroupSize (which needs the
189 // GridOp).
190 Value resOffsets;
191 if (adaptor.getStaticShardedDimsOffsets().empty()) {
192 resOffsets = tensor::EmptyOp::create(rewriter, loc,
193 std::array<int64_t, 2>{0, 0}, i64);
194 } else {
195 SymbolTableCollection symbolTableCollection;
196 auto gridOp = getGrid(op, symbolTableCollection);
197 int64_t maxSplitSize = 0;
198 for (auto axes : splitAxes) {
199 int64_t splitSize =
200 collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
201 assert(splitSize != ShapedType::kDynamic);
202 maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
203 }
204 assert(maxSplitSize);
205 ++maxSplitSize; // add one for the total size
206
207 resOffsets = tensor::EmptyOp::create(
208 rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
209 Value zero = arith::ConstantOp::create(
210 rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
211 resOffsets =
212 linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0);
213 SmallVector<Value> offsets =
214 getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
215 adaptor.getDynamicShardedDimsOffsets());
216 int64_t curr = 0;
217 for (auto [i, axes] : llvm::enumerate(splitAxes)) {
218 int64_t splitSize =
219 collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
220 assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
221 ++splitSize; // add one for the total size
222 ArrayRef<Value> values(&offsets[curr], splitSize);
223 Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
224 std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
225 std::array<int64_t, 2> sizes = {1, splitSize};
226 resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
227 resOffsets, empty, empty,
228 empty, offs, sizes, strides);
229 curr += splitSize;
230 }
231 }
232
233 // return a tuple of tensors as defined by type converter
234 SmallVector<Type> resTypes;
235 if (failed(getTypeConverter()->convertType(op.getResult().getType(),
236 resTypes)))
237 return failure();
238
239 resSplitAxes =
240 tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
241 resHaloSizes =
242 tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
243 resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);
244
245 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
246 op, TupleType::get(op.getContext(), resTypes),
247 ValueRange{resSplitAxes, resHaloSizes, resOffsets});
248
249 return success();
250 }
251};
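// Illustration (not part of this file): a plain-C++ sketch of the split-axes
// encoding built above, using hypothetical inputs; each row holds the grid
// axes of one tensor dimension, padded with -1 up to the longest group.
#include <algorithm>
#include <cstdint>
#include <vector>

static std::vector<int16_t>
packSplitAxes(const std::vector<std::vector<int16_t>> &splitAxes) {
  size_t maxNAxes = 0;
  for (const auto &axes : splitAxes)
    maxNAxes = std::max(maxNAxes, axes.size());
  // Row-major {splitAxes.size(), maxNAxes} tensor, filled with -1.
  std::vector<int16_t> packed(splitAxes.size() * maxNAxes, -1);
  for (size_t i = 0; i < splitAxes.size(); ++i)
    for (size_t j = 0; j < splitAxes[i].size(); ++j)
      packed[i * maxNAxes + j] = splitAxes[i][j];
  return packed;
}
// E.g. {{0}, {1, 2}} packs into the 2x2 tensor {{0, -1}, {1, 2}}.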
252
253struct ConvertProcessMultiIndexOp
254 : public OpConversionPattern<ProcessMultiIndexOp> {
255 using OpConversionPattern::OpConversionPattern;
256
257 LogicalResult
258 matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
259 ConversionPatternRewriter &rewriter) const override {
260
261 // Currently converts its linear index to a multi-dimensional index.
262
263 SymbolTableCollection symbolTableCollection;
264 Location loc = op.getLoc();
265 auto gridOp = getGrid(op, symbolTableCollection);
266 // For now we only support static grid shapes
267 if (ShapedType::isDynamicShape(gridOp.getShape()))
268 return failure();
269
270 SmallVector<Value> dims;
271 llvm::transform(
272 gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
273 return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
274 });
275 Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
276 auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
277
278 // optionally extract subset of grid axes
279 auto axes = adaptor.getAxes();
280 if (!axes.empty()) {
281 SmallVector<Value> subIndex;
282 for (auto axis : axes) {
283 subIndex.emplace_back(mIdx[axis]);
284 }
285 mIdx = std::move(subIndex);
286 }
287
288 rewriter.replaceOp(op, mIdx);
289 return success();
290 }
291};
292
293class ConvertProcessLinearIndexOp
294 : public OpConversionPattern<ProcessLinearIndexOp> {
295
296public:
297 using OpConversionPattern::OpConversionPattern;
298
299 LogicalResult
300 matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
301 ConversionPatternRewriter &rewriter) const override {
302 // Create mpi::CommRankOp
303 Location loc = op.getLoc();
304 auto *ctx = op.getContext();
305 Value commWorld =
306 mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
307 auto rank = mpi::CommRankOp::create(
308 rewriter, loc,
309 TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
310 commWorld)
311 .getRank();
312 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
313 rank);
314 return success();
315 }
316};
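// Illustration (not part of this file): at runtime the lowered IR amounts to
// querying the process rank on MPI_COMM_WORLD, roughly as in this sketch.
#include <mpi.h>

static int getProcessLinearIndex() {
  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank); // the rank is the linear process index
  return rank;
}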
317
318struct ConvertNeighborsLinearIndicesOp
319 : public OpConversionPattern<NeighborsLinearIndicesOp> {
320 using OpConversionPattern::OpConversionPattern;
321
322 LogicalResult
323 matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
324 ConversionPatternRewriter &rewriter) const override {
325
326 // Computes the neighbors indices along a split axis by simply
327 // adding/subtracting 1 to the current index in that dimension.
328 // Assigns -1 if neighbor is out of bounds.
329
330 auto axes = adaptor.getSplitAxes();
331 // For now only single axis sharding is supported
332 if (axes.size() != 1)
333 return failure();
334
335 Location loc = op.getLoc();
336 SymbolTableCollection symbolTableCollection;
337 auto gridOp = getGrid(op, symbolTableCollection);
338 auto mIdx = adaptor.getDevice();
339 auto orgIdx = mIdx[axes[0]];
340 SmallVector<Value> dims;
341 llvm::transform(
342 gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
343 return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
344 });
345 Value dimSz = dims[axes[0]];
346 Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
347 Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1);
348 Value atBorder =
349 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
350 arith::ConstantIndexOp::create(rewriter, loc, 0));
351 auto down = scf::IfOp::create(
352 rewriter, loc, atBorder,
353 [&](OpBuilder &builder, Location loc) {
354 scf::YieldOp::create(builder, loc, minus1);
355 },
356 [&](OpBuilder &builder, Location loc) {
357 SmallVector<Value> tmp = mIdx;
358 tmp[axes[0]] =
359 arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
360 .getResult();
361 scf::YieldOp::create(builder, loc,
362 multiToLinearIndex(loc, rewriter, tmp, dims));
363 });
364 atBorder = arith::CmpIOp::create(
365 rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
366 arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
367 auto up = scf::IfOp::create(
368 rewriter, loc, atBorder,
369 [&](OpBuilder &builder, Location loc) {
370 scf::YieldOp::create(builder, loc, minus1);
371 },
372 [&](OpBuilder &builder, Location loc) {
373 SmallVector<Value> tmp = mIdx;
374 tmp[axes[0]] =
375 arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
376 scf::YieldOp::create(builder, loc,
377 multiToLinearIndex(loc, rewriter, tmp, dims));
378 });
379 rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
380 return success();
381 }
382};
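// Illustration (not part of this file): a scalar sketch of the neighbor
// computation above for a single split axis; -1 marks a missing neighbor at
// the grid border, and all names are hypothetical.
#include <cstdint>
#include <utility>
#include <vector>

static std::pair<int64_t, int64_t>
neighborsAlongAxis(std::vector<int64_t> multiIdx,
                   const std::vector<int64_t> &gridShape, int64_t axis) {
  auto linearize = [&](const std::vector<int64_t> &idx) {
    int64_t linear = 0;
    for (size_t i = 0; i < idx.size(); ++i)
      linear = linear * gridShape[i] + idx[i]; // row-major linearization
    return linear;
  };
  int64_t org = multiIdx[axis];
  int64_t down = -1, up = -1;
  if (org > 0) { // neighbor towards the lower index
    multiIdx[axis] = org - 1;
    down = linearize(multiIdx);
  }
  if (org < gridShape[axis] - 1) { // neighbor towards the higher index
    multiIdx[axis] = org + 1;
    up = linearize(multiIdx);
  }
  return {down, up};
}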
383
384struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
385 using OpConversionPattern::OpConversionPattern;
386
387 LogicalResult
388 matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
389 ConversionPatternRewriter &rewriter) const override {
390 auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
391 if (!sharding) {
392 return op->emitError()
393 << "Expected ShardingOp as defining op for sharding"
394 << " but found " << adaptor.getSharding()[0].getDefiningOp();
395 }
396
397 // Compute the sharded shape by applying the sharding to the input shape.
398 // If shardedDimsOffsets is not defined in the sharding, the shard shape is
399 // computed by dividing the dimension size by the number of shards in that
400 // dimension (which is given by the size of the grid axes provided in
401 // split-axes). Odd elements get distributed to trailing shards. If a
402 // shardedDimsOffsets is provided, the shard shape is computed by
403 // subtracting the offset of the current shard from the offset of the next
404 // shard.
405
406 Location loc = op.getLoc();
407 Type index = rewriter.getIndexType();
408
409 // This is a 1:N conversion because the sharding op is a 1:3 conversion.
410 // The operands in the adaptor are a vector<ValueRange>. For dims and device
411 // we have a 1:1 conversion.
412 // For simpler access fill a vector with the dynamic dims.
413 SmallVector<Value> dynDims, dynDevice;
414 for (auto dim : adaptor.getDimsDynamic()) {
415 // type conversion should be 1:1 for ints
416 dynDims.emplace_back(llvm::getSingleElement(dim));
417 }
418 // same for device
419 for (auto device : adaptor.getDeviceDynamic()) {
420 dynDevice.emplace_back(llvm::getSingleElement(device));
421 }
422
423 // To keep the code simple, convert dims/device to values when they are
424 // attributes. Count on canonicalization to fold static values.
425 SmallVector<Value> shape =
426 getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
427 SmallVector<Value> multiIdx =
428 getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
429
430 // Get the GridOp, the grid shape is needed to compute the sharded shape.
431 SymbolTableCollection symbolTableCollection;
432 auto gridOp = getGrid(sharding, symbolTableCollection);
433 // For now we only support static grid shapes
434 if (ShapedType::isDynamicShape(gridOp.getShape()))
435 return failure();
436
437 auto splitAxes = sharding.getSplitAxes().getAxes();
438 // shardedDimsOffsets are optional and might be Values (not attributes).
439 // Also, the shardId might be dynamic which means the position in the
440 // shardedDimsOffsets is not statically known. Create a tensor of the
441 // shardedDimsOffsets and later extract the offsets for computing the
442 // local shard-size.
443 Value shardedDimsOffs;
444 {
445 SmallVector<Value> tmp = getMixedAsValues(
446 rewriter, loc, sharding.getStaticShardedDimsOffsets(),
447 sharding.getDynamicShardedDimsOffsets(), index);
448 if (!tmp.empty())
449 shardedDimsOffs = tensor::FromElementsOp::create(
450 rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
451 tmp);
452 }
453
454 // With static grid shape the sizes of the split axes are known.
455 // Hence the start/pos for each split axes in shardDimsOffsets can be
456 // computed statically.
457 int64_t pos = 0;
458 SmallVector<Value> shardShape;
459 Value zero =
460 arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
461 Value one =
462 arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));
463
464 // Iterate over the dimensions of the tensor shape, get their split Axes,
465 // and compute the sharded shape.
466 for (auto [i, dim] : llvm::enumerate(shape)) {
467 // Trailing dimensions might not be annotated.
468 if (i < splitAxes.size() && !splitAxes[i].empty()) {
469 auto axes = splitAxes[i];
470 // The current dimension might not be sharded.
471 // Create a value from the static position in shardDimsOffsets.
472 Value posVal = arith::ConstantOp::create(rewriter, loc,
473 rewriter.getIndexAttr(pos));
474 // Get the index of the local shard in the grid axis.
475 Value idx = multiIdx[axes[0]];
476 auto numShards =
477 collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
478 if (shardedDimsOffs) {
479 // If sharded dims offsets are provided, use them to compute the
480 // sharded shape.
481 if (axes.size() > 1) {
482 return op->emitError() << "Only single axis sharding is "
483 << "supported for each dimension.";
484 }
485 idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
486 // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
487 Value off =
488 tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
489 idx = arith::AddIOp::create(rewriter, loc, idx, one);
490 Value nextOff =
491 tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
492 Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
493 shardShape.emplace_back(sz);
494 } else {
495 Value numShardsVal = arith::ConstantOp::create(
496 rewriter, loc, rewriter.getIndexAttr(numShards));
497 // Compute shard dim size by distributing odd elements to trailing
498 // shards:
499 // sz = dim / numShards
500 // + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
501 Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
502 Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
503 sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
504 auto cond = arith::CmpIOp::create(
505 rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
506 Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
507 sz = arith::AddIOp::create(rewriter, loc, sz, odd);
508 shardShape.emplace_back(sz);
509 }
510 pos += numShards + 1; // add one for the total size.
511 } // else no sharding if split axis is empty or no split axis
512 // If no size was added -> no sharding in this dimension.
513 if (shardShape.size() <= i)
514 shardShape.emplace_back(dim);
515 }
516 assert(shardShape.size() == shape.size());
517 rewriter.replaceOp(op, shardShape);
518 return success();
519 }
520};
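// Illustration (not part of this file): the default shard-size rule above (no
// sharded-dims-offsets) in plain C++; trailing shards absorb the remainder,
// so shard sizes differ by at most one.
#include <cstdint>

static int64_t defaultShardSize(int64_t dimSize, int64_t numShards,
                                int64_t shardIdx) {
  int64_t base = dimSize / numShards;
  int64_t rem = dimSize % numShards;
  // The last `rem` shards get one extra element.
  return base + (shardIdx >= numShards - rem ? 1 : 0);
}
// E.g. dimSize = 10, numShards = 4 yields shard sizes 2, 2, 3, 3.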
521
522static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
523 auto *ctx = kind.getContext();
524 auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
525 return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
526 };
527
528 switch (kind.getValue()) {
529 case ReductionKind::Sum:
530 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
531 case ReductionKind::Product:
532 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
533 case ReductionKind::Min:
534 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
535 case ReductionKind::Max:
536 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
537 case ReductionKind::BitwiseAnd:
538 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
539 case ReductionKind::BitwiseOr:
540 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
541 case ReductionKind::BitwiseXor:
542 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
543 default:
544 llvm_unreachable("Unknown/unsupported reduction kind");
545 }
546}
547
548struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
549 using OpConversionPattern::OpConversionPattern;
550
551 LogicalResult
552 matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
553 ConversionPatternRewriter &rewriter) const override {
554 SymbolTableCollection symbolTableCollection;
555 auto grid = adaptor.getGrid();
556 mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
557 if (!gridOp)
558 return op->emitError() << "No grid found for AllReduceOp";
559 if (ShapedType::isDynamicShape(gridOp.getShape()))
560 return op->emitError()
561 << "Dynamic grid shape not supported in AllReduceOp";
562
563 ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
564 Value input = adaptor.getInput();
565 auto inputShape = cast<ShapedType>(input.getType()).getShape();
566
567 // If the source is a tensor, convert it to a buffer (memref).
568 if (isa<RankedTensorType>(input.getType())) {
569 auto memrefType = MemRefType::get(
570 inputShape, cast<ShapedType>(input.getType()).getElementType());
571 input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
572 }
573 MemRefType inType = cast<MemRefType>(input.getType());
574
575 // Get the actual shape to allocate the buffer.
576 SmallVector<OpFoldResult> shape(inType.getRank());
577 for (auto i = 0; i < inType.getRank(); ++i) {
578 auto s = inputShape[i];
579 if (ShapedType::isDynamic(s))
580 shape[i] = memref::DimOp::create(iBuilder, input, s).getResult();
581 else
582 shape[i] = iBuilder.getIndexAttr(s);
583 }
584
585 // Allocate buffer and copy input to buffer.
586 Value buffer = memref::AllocOp::create(
587 iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
588 linalg::CopyOp::create(iBuilder, input, buffer);
589
590 // Get an MPI_Comm_split for the AllReduce operation.
591 // The color is the linear index of the process in the grid along the
592 // non-reduced axes. The key is the linear index of the process in the grid
593 // along the reduced axes.
594 SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
595 iBuilder.getIndexType());
596 SmallVector<Value> myMultiIndex =
597 ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
598 .getResult();
599 Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
600 SmallVector<Value> multiKey(myMultiIndex.size(), zero);
601
602 auto redAxes = adaptor.getGridAxes();
603 for (auto axis : redAxes) {
604 multiKey[axis] = myMultiIndex[axis];
605 myMultiIndex[axis] = zero;
606 }
607
608 Value color =
609 createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
610 color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
611 Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
612 key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
613
614 // Finally split the communicator
615 auto commType = mpi::CommType::get(op->getContext());
616 Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
617 auto comm =
618 mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
619 .getNewcomm();
620
621 Value buffer1d = buffer;
622 // Collapse shape to 1d if needed
623 if (inType.getRank() > 1) {
624 ReassociationIndices reassociation(inType.getRank());
625 std::iota(reassociation.begin(), reassociation.end(), 0);
626 buffer1d = memref::CollapseShapeOp::create(
627 iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
628 }
629
630 // Create the MPI AllReduce operation.
631 mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
632 getMPIReductionOp(adaptor.getReductionAttr()),
633 comm);
634
635 // If the result is a tensor, convert the buffer back to a tensor.
636 if (isa<RankedTensorType>(op.getType()))
637 buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
638 true);
639
640 rewriter.replaceOp(op, buffer);
641 return success();
642 }
643};
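// Illustration (not part of this file): the runtime effect of the lowering
// above, sketched with the MPI C API. Processes sharing a color reduce
// together; the key orders them inside the split communicator. The generated
// IR reduces a copy buffer, simplified here to an in-place reduction.
#include <mpi.h>

static void allReduceSum(double *buf, int count, int color, int key) {
  MPI_Comm comm;
  MPI_Comm_split(MPI_COMM_WORLD, color, key, &comm);
  MPI_Allreduce(MPI_IN_PLACE, buf, count, MPI_DOUBLE, MPI_SUM, comm);
  MPI_Comm_free(&comm);
}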
644
645struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
646 using OpConversionPattern::OpConversionPattern;
647
648 LogicalResult
649 matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
650 ConversionPatternRewriter &rewriter) const override {
651
652 // The input/output memref is assumed to be in C memory order.
653 // Halos are exchanged as 2 blocks per dimension (one for each side: down
654 // and up). For each haloed dimension `d`, the exchanged blocks are
655 // expressed as multi-dimensional subviews. The subviews include potential
656 // halos of higher dimensions `dh > d`, no halos for the lower dimensions
657 // `dl < d` and for dimension `d` the currently exchanged halo only.
658 // By iterating from higher to lower dimensions this also updates the halos
659 // in the 'corners'.
660 // memref.subview is used to read and write the halo data from and to the
661 // local data. Because subviews and halos can have mixed dynamic and static
662 // shapes, OpFoldResults are used whenever possible.
663
664 auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
665 adaptor.getHaloSizes(), rewriter);
666 if (haloSizes.empty()) {
667 // no halos -> nothing to do
668 rewriter.replaceOp(op, adaptor.getDestination());
669 return success();
670 }
671
672 SymbolTableCollection symbolTableCollection;
673 Location loc = op.getLoc();
674
675 // convert an OpFoldResult into a Value
676 auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
677 if (auto value = dyn_cast<Value>(v))
678 return value;
679 return arith::ConstantOp::create(
680 rewriter, loc,
681 rewriter.getIndexAttr(
682 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
683 };
684
685 auto dest = adaptor.getDestination();
686 auto dstShape = cast<ShapedType>(dest.getType()).getShape();
687 Value array = dest;
688 if (isa<RankedTensorType>(array.getType())) {
689 // If the destination is a tensor, cast it to a buffer (memref).
690 auto mmemrefType = MemRefType::get(
691 dstShape, cast<ShapedType>(array.getType()).getElementType());
692 array =
693 bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array);
694 }
695 auto rank = cast<ShapedType>(array.getType()).getRank();
696 auto opSplitAxes = adaptor.getSplitAxes().getAxes();
697 auto grid = adaptor.getGrid();
698 auto gridOp = getGrid(op, symbolTableCollection);
699 // subviews need Index values
700 for (auto &sz : haloSizes) {
701 if (auto value = dyn_cast<Value>(sz))
702 sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
703 value)
704 .getResult();
705 }
706
707 // most of the offset/size/stride data is the same for all dims
708 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
709 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
710 SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
711 auto currHaloDim = -1; // halo sizes are provided for split dimensions only
712 // we need the actual shape to compute offsets and sizes
713 for (auto i = 0; i < rank; ++i) {
714 auto s = dstShape[i];
715 if (ShapedType::isDynamic(s))
716 shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult();
717 else
718 shape[i] = rewriter.getIndexAttr(s);
719
720 if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
721 ++currHaloDim;
722 // the offsets for lower dims start after their down halo
723 offsets[i] = haloSizes[currHaloDim * 2];
724
725 // prepare shape and offsets of highest dim's halo exchange
726 Value _haloSz = arith::AddIOp::create(
727 rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
728 toValue(haloSizes[currHaloDim * 2 + 1]));
729 // the halo shape of lower dims excludes the halos
730 dimSizes[i] =
731 arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
732 .getResult();
733 } else {
734 dimSizes[i] = shape[i];
735 }
736 }
737
738 auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
739 auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
740 auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
741 auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
742
743 SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
744 rewriter.getIndexType());
745 auto myMultiIndex =
746 ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
747 .getResult();
748 // traverse all split axes from high to low dim
749 for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
750 auto splitAxes = opSplitAxes[dim];
751 if (splitAxes.empty())
752 continue;
753 assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
754 // Get the linearized ids of the neighbors (down and up) for the
755 // given split
756 auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
757 myMultiIndex, splitAxes)
758 .getResults();
759 // MPI operates on i32...
760 Value neighbourIDs[2] = {
761 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
762 tmp[0]),
763 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
764 tmp[1])};
765
766 auto lowerRecvOffset = rewriter.getIndexAttr(0);
767 auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
768 auto upperRecvOffset =
769 arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
770 toValue(haloSizes[currHaloDim * 2 + 1]));
771 auto upperSendOffset = arith::SubIOp::create(
772 rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
773
774 Value commWorld = mpi::CommWorldOp::create(
775 rewriter, loc, mpi::CommType::get(op->getContext()));
776
777 // Make sure we send/recv in a way that does not lead to a dead-lock.
778 // The current approach is far from optimal; it should at least use
779 // a red-black pattern or MPI_Sendrecv.
780 // Also, buffers should be re-used.
781 // Still using temporary contiguous buffers for MPI communication...
782 // Still yielding a "serialized" communication pattern...
783 auto genSendRecv = [&](bool upperHalo) {
784 auto orgOffset = offsets[dim];
785 dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
786 : haloSizes[currHaloDim * 2];
787 // Check if we need to send and/or receive
788 // Processes on the grid borders have only one neighbor
789 auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
790 auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
791 auto hasFrom = arith::CmpIOp::create(
792 rewriter, loc, arith::CmpIPredicate::sge, from, zero);
793 auto hasTo = arith::CmpIOp::create(rewriter, loc,
794 arith::CmpIPredicate::sge, to, zero);
795 auto buffer = memref::AllocOp::create(
796 rewriter, loc, dimSizes,
797 cast<ShapedType>(array.getType()).getElementType());
798 // if has neighbor: copy halo data from array to buffer and send
799 scf::IfOp::create(
800 rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
801 offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
802 : OpFoldResult(upperSendOffset);
803 auto subview = memref::SubViewOp::create(
804 builder, loc, array, offsets, dimSizes, strides);
805 memref::CopyOp::create(builder, loc, subview, buffer);
806 mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
807 commWorld);
808 scf::YieldOp::create(builder, loc);
809 });
810 // if has neighbor: receive halo data into buffer and copy to array
811 scf::IfOp::create(
812 rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
813 offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
814 : OpFoldResult(lowerRecvOffset);
815 mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
816 commWorld);
817 auto subview = memref::SubViewOp::create(
818 builder, loc, array, offsets, dimSizes, strides);
819 memref::CopyOp::create(builder, loc, buffer, subview);
820 scf::YieldOp::create(builder, loc);
821 });
822 memref::DeallocOp::create(rewriter, loc, buffer);
823 offsets[dim] = orgOffset;
824 };
825
826 auto doSendRecv = [&](int upOrDown) {
827 OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
828 Value haloSz = dyn_cast<Value>(v);
829 if (!haloSz)
830 haloSz = arith::ConstantOp::create(
831 rewriter, loc,
832 rewriter.getI32IntegerAttr(
833 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
834 auto hasSize = arith::CmpIOp::create(
835 rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
836 scf::IfOp::create(rewriter, loc, hasSize,
837 [&](OpBuilder &builder, Location loc) {
838 genSendRecv(upOrDown > 0);
839 scf::YieldOp::create(builder, loc);
840 });
841 };
842
843 doSendRecv(0);
844 doSendRecv(1);
845
846 // the shape for lower dims includes higher dims' halos
847 dimSizes[dim] = shape[dim];
848 // -> the offset for higher dims is always 0
849 offsets[dim] = rewriter.getIndexAttr(0);
850 // on to next halo
851 --currHaloDim;
852 }
853
854 if (isa<MemRefType>(op.getResult().getType())) {
855 rewriter.replaceOp(op, array);
856 } else {
857 assert(isa<RankedTensorType>(op.getResult().getType()));
858 rewriter.replaceOp(op, bufferization::ToTensorOp::create(
859 rewriter, loc, op.getResult().getType(), array,
860 /*restrict=*/true, /*writable=*/true));
861 }
862 return success();
863 }
864};
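// Illustration (not part of this file): a 1-D sketch of the halo exchange
// above using the MPI C API. Like the generated IR it is blocking and
// serialized; `lo`/`hi` are neighbor ranks or -1 at the grid border, the halo
// widths are `haloLo`/`haloHi`, and all names are hypothetical.
#include <mpi.h>

static void exchangeHalos1D(double *data, int n, int haloLo, int haloHi,
                            int lo, int hi, int tag) {
  // Fill upper halos: the lowest interior block goes down, the local upper
  // halo arrives from the upper neighbor.
  if (lo >= 0)
    MPI_Send(data + haloLo, haloHi, MPI_DOUBLE, lo, tag, MPI_COMM_WORLD);
  if (hi >= 0)
    MPI_Recv(data + n - haloHi, haloHi, MPI_DOUBLE, hi, tag, MPI_COMM_WORLD,
             MPI_STATUS_IGNORE);
  // Fill lower halos: the highest interior block goes up, the local lower
  // halo arrives from the lower neighbor.
  if (hi >= 0)
    MPI_Send(data + n - haloHi - haloLo, haloLo, MPI_DOUBLE, hi, tag,
             MPI_COMM_WORLD);
  if (lo >= 0)
    MPI_Recv(data, haloLo, MPI_DOUBLE, lo, tag, MPI_COMM_WORLD,
             MPI_STATUS_IGNORE);
}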
865
866struct ConvertShardToMPIPass
867 : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
868 using Base::Base;
869
870 /// Run the dialect converter on the module.
871 void runOnOperation() override {
872 auto *ctxt = &getContext();
873 RewritePatternSet patterns(ctxt);
874 ConversionTarget target(getContext());
875
876 // Define a type converter to convert shard::ShardingType,
877 // mostly for use in return operations.
878 TypeConverter typeConverter;
879 typeConverter.addConversion([](Type type) { return type; });
880
881 // convert shard::ShardingType to a tuple of RankedTensorTypes
882 typeConverter.addConversion(
883 [](ShardingType type,
884 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
885 auto i16 = IntegerType::get(type.getContext(), 16);
886 auto i64 = IntegerType::get(type.getContext(), 64);
887 std::array<int64_t, 2> shp = {ShapedType::kDynamic,
888 ShapedType::kDynamic};
889 results.emplace_back(RankedTensorType::get(shp, i16));
890 results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
891 results.emplace_back(RankedTensorType::get(shp, i64));
892 return success();
893 });
894
895 // To 'extract' components, an UnrealizedConversionCastOp is expected
896 // to define the input
897 typeConverter.addTargetMaterialization(
898 [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
899 Location loc) {
900 // Expecting a single input.
901 if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
902 return SmallVector<Value>();
903 auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
904 // Expecting an UnrealizedConversionCastOp.
905 if (!castOp)
906 return SmallVector<Value>();
907 // Fill a vector with elements of the tuple/castOp.
908 SmallVector<Value> results;
909 for (auto oprnd : castOp.getInputs()) {
910 if (!isa<RankedTensorType>(oprnd.getType()))
911 return SmallVector<Value>();
912 results.emplace_back(oprnd);
913 }
914 return results;
915 });
916
917 // No shard dialect ops should be left after conversion...
918 target.addIllegalDialect<shard::ShardDialect>();
919 // ...except the global GridOp and GridShapeOp, which will get folded later.
920 target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
921 // Allow all the stuff that our patterns will convert to
922 target.addLegalDialect<
923 BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
924 tensor::TensorDialect, bufferization::BufferizationDialect,
925 linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
926 // Make sure the function signature, calls etc. are legal
927 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
928 return typeConverter.isSignatureLegal(op.getFunctionType());
929 });
930 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
931 [&](Operation *op) { return typeConverter.isLegal(op); });
932
933 patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
934 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
935 ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
936 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
937
938 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
939 patterns, typeConverter);
940 populateCallOpTypeConversionPattern(patterns, typeConverter);
941 populateReturnOpTypeConversionPattern(patterns, typeConverter);
942
943 (void)applyPartialConversion(getOperation(), target, std::move(patterns));
944
945 // Folding patterns cannot be mixed with conversion patterns -> extra pass.
946 patterns.clear();
947 SymbolTableCollection symbolTableCollection;
948 mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
949 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
950 }
951};
952
953} // namespace
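// Illustration (not part of this file): a minimal sketch of running this
// conversion from C++, assuming the generated factory is named
// createConvertShardToMPIPass (declared via mlir/Conversion/Passes.h).
#include "mlir/Conversion/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

static mlir::LogicalResult lowerShardToMPI(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createConvertShardToMPIPass()); // assumed factory name
  return pm.run(module);
}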