MLIR 23.0.0git
ShardToMPI.cpp
Go to the documentation of this file.
1//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation of Shard communication ops to MPI ops.
10//
11//===----------------------------------------------------------------------===//
12
15
33#include "mlir/IR/AffineMap.h"
34#include "mlir/IR/Builders.h"
38#include "mlir/IR/SymbolTable.h"
41
42#define DEBUG_TYPE "shard-to-mpi"
43
44namespace mlir {
45#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
46#include "mlir/Conversion/Passes.h.inc"
47} // namespace mlir
48
49using namespace mlir;
50using namespace shard;
51
52namespace {
53/// Converts a vector of OpFoldResults (ints) into vector of Values of the
54/// provided type.
57 ValueRange dynamics,
58 Type type = Type()) {
59 SmallVector<Value> values;
60 auto dyn = dynamics.begin();
61 Type i64 = b.getI64Type();
62 if (!type)
63 type = i64;
64 assert((i64 == type || b.getIndexType() == type) &&
65 "expected an i64 or an intex type");
66 for (auto s : statics) {
67 if (s == ShapedType::kDynamic) {
68 values.emplace_back(*(dyn++));
69 } else {
70 TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
71 values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
72 }
73 }
74 return values;
75}
76
77/// Create operations converting a linear index to a multi-dimensional index.
78[[maybe_unused]] static SmallVector<Value>
79linearToMultiIndex(Location loc, OpBuilder b, Value linearIndex,
80 ValueRange dimensions) {
81 int n = dimensions.size();
82 SmallVector<Value> multiIndex(n);
83
84 for (int i = n - 1; i >= 0; --i) {
85 multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
86 if (i > 0)
87 linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
88 }
89
90 return multiIndex;
91}
92
93/// Create operations converting a multi-dimensional index to a linear index.
94Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
95 ValueRange dimensions) {
96
97 Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0);
98 Value stride = arith::ConstantIndexOp::create(b, loc, 1);
99
100 for (int i = multiIndex.size() - 1; i >= 0; --i) {
101 Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
102 linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
103 stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
104 }
105
106 return linearIndex;
107}
108
109/// Replace GetShardingOp with related/dependent ShardingOp.
110struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
111 using OpConversionPattern::OpConversionPattern;
112
113 LogicalResult
114 matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
115 ConversionPatternRewriter &rewriter) const override {
116 auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
117 if (!shardOp)
118 return failure();
119 auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
120 if (!shardingOp)
121 return failure();
122
123 rewriter.replaceOp(op, shardingOp.getResult());
124 return success();
125 }
126};
127
/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by type converter.
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    // The widest split-group determines the column count of the 2d tensor.
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // To hold the split axes, create empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
    // -1 marks "no axis" padding for rows shorter than maxNAxes.
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
    resSplitAxes =
        linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
            .getResult(0);

    // explicitly write values into tensor row by row
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      // Materialize the row as a constant 1-d tensor and insert it as a
      // slice at row i; empty rows insert a zero-sized slice (no-op data).
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
      resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resSplitAxes, empty, empty,
                                                   empty, offs, sizes, strides);
    }

    // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    // No halo sizes -> emit a 0x0 placeholder tensor instead.
    Value resHaloSizes =
        haloSizes.empty()
            ? tensor::EmptyOp::create(rewriter, loc,
                                      std::array<int64_t, 2>{0, 0}, i64)
                  .getResult()
            : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
                  .getResult();

    // To hold sharded dims offsets, create Tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
    // elements for smaller split-groups to -1. Computing the max size of the
    // split groups needs using collectiveProcessGroupSize (which needs the
    // GridOp)
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      // No offsets provided -> 0x0 placeholder, mirroring the halo case.
      resOffsets = tensor::EmptyOp::create(rewriter, loc,
                                           std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto gridOp = getGrid(op, symbolTableCollection);
      // Row width is the largest split-group size (+1 for the total, below).
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = tensor::EmptyOp::create(
          rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      // NOTE(review): despite the name "zero", the fill value is
      // ShapedType::kDynamic, apparently used as an "unset" marker for
      // trailing elements — confirm intent against upstream.
      Value zero = arith::ConstantOp::create(
          rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0);
      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      // `curr` tracks the flat read position in `offsets`; each row consumes
      // splitSize+1 entries (offsets plus the total).
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resOffsets, empty, empty,
                                                   empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // return a tuple of tensors as defined by type converter
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    // Cast each component to the exact converted type before packing.
    resSplitAxes =
        tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
    resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};
256
257class ConvertProcessLinearIndexOp
258 : public OpConversionPattern<ProcessLinearIndexOp> {
259
260public:
261 using OpConversionPattern::OpConversionPattern;
262
263 LogicalResult
264 matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter) const override {
266 // Create mpi::CommRankOp
267 Location loc = op.getLoc();
268 auto *ctx = op.getContext();
269 Value commWorld =
270 mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
271 auto rank = mpi::CommRankOp::create(
272 rewriter, loc,
273 TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
274 commWorld)
275 .getRank();
276 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
277 rank);
278 return success();
279 }
280};
281
/// Lower NeighborsLinearIndicesOp: yields the linear indices of the two
/// neighboring processes (down/up) along a single split axis.
struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Computes the neighbors indices along a split axis by simply
    // adding/subtracting 1 to the current index in that dimension.
    // Assigns -1 if neighbor is out of bounds.

    auto axes = adaptor.getSplitAxes();
    // For now only single axis sharding is supported
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(op, symbolTableCollection);
    // Multi-dimensional index of this process; orgIdx is its coordinate on
    // the (single) split axis.
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    // Materialize all grid dimension sizes as index constants.
    SmallVector<Value> dims;
    llvm::transform(
        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1);
    // "Down" neighbor: -1 at the lower border, otherwise linearize the
    // coordinate decremented by one.
    Value atBorder =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
                              arith::ConstantIndexOp::create(rewriter, loc, 0));
    auto down = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          // NOTE(review): this callback uses the outer `rewriter` and
          // `op.getLoc()` instead of the supplied `builder`/`loc` — the ops
          // are created at the rewriter's insertion point, not inside the
          // else-region. Presumably intentional (values dominate the scf.if),
          // but confirm against upstream.
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
                  .getResult();
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    // "Up" neighbor: -1 at the upper border (orgIdx >= dimSz - 1),
    // otherwise linearize the coordinate incremented by one.
    atBorder = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
        arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
    auto up = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          // NOTE(review): same rewriter-vs-builder mix as above.
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};
347
/// Lower ShardShapeOp: compute the local (per-process) shape of a sharded
/// tensor from the global shape, the sharding, and the device index.
struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      return op->emitError()
             << "Expected ShardingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
    // computed by dividing the dimension size by the number of shards in that
    // dimension (which is given by the size of the grid axes provided in
    // split-axes). Odd elements get distributed to trailing shards. If a
    // shardedDimsOffsets is provided, the shard shape is computed by
    // subtracting the offset of the current shard from the offset of the next
    // shard.

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValeRange>. For dims and device
    // we have a 1:1 conversion.
    // For simpler access fill a vector with the dynamic dims.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      // type conversion should be 1:1 for ints
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    // same for device
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    // To keep the code simple, convert dims/device to values when they are
    // attributes. Count on canonicalization to fold static values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    // Get the GridOp, the grid shape is needed to compute the sharded shape.
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(sharding, symbolTableCollection);
    // For now we only support static grid shapes
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();
    // shardedDimsOffsets are optional and might be Values (not attributes).
    // Also, the shardId might be dynamic which means the position in the
    // shardedDimsOffsets is not statically known. Create a tensor of the
    // shardedDimsOffsets and later extract the offsets for computing the
    // local shard-size.
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = tensor::FromElementsOp::create(
            rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
            tmp);
    }

    // With static grid shape the sizes of the split axes are known.
    // Hence the start/pos for each split axes in shardDimsOffsets can be
    // computed statically.
    int64_t pos = 0;
    SmallVector<Value> shardShape;
    Value zero =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
    Value one =
        arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape, get their split Axes,
    // and compute the sharded shape.
    for (auto [i, dim] : llvm::enumerate(shape)) {
      // Trailing dimensions might not be annotated.
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        // The current dimension might not be sharded.
        // Create a value from the static position in shardDimsOffsets.
        Value posVal = arith::ConstantOp::create(rewriter, loc,
                                                 rewriter.getIndexAttr(pos));
        // Get the index of the local shard in the grid axis.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        if (shardedDimsOffs) {
          // If sharded dims offsets are provided, use them to compute the
          // sharded shape.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
          Value off =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          idx = arith::AddIOp::create(rewriter, loc, idx, one);
          Value nextOff =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = arith::ConstantOp::create(
              rewriter, loc, rewriter.getIndexAttr(numShards));
          // Compute shard dim size by distributing odd elements to trailing
          // shards:
          // sz = dim / numShards
          //      + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
          Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
          Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
          sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
          auto cond = arith::CmpIOp::create(
              rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
          sz = arith::AddIOp::create(rewriter, loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        // Advance the static read position past this dimension's row
        // (numShards offsets plus the total).
        pos += numShards + 1; // add one for the total size.
      } // else no sharding if split axis is empty or no split axis
      // If no size was added -> no sharding in this dimension.
      if (shardShape.size() <= i)
        shardShape.emplace_back(dim);
    }
    assert(shardShape.size() == shape.size());
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};
485
486static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
487 auto *ctx = kind.getContext();
488 auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
489 return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
490 };
491
492 switch (kind.getValue()) {
493 case ReductionKind::Sum:
494 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
495 case ReductionKind::Product:
496 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
497 case ReductionKind::Min:
498 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
499 case ReductionKind::Max:
500 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
501 case ReductionKind::BitwiseAnd:
502 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
503 case ReductionKind::BitwiseOr:
504 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
505 case ReductionKind::BitwiseXor:
506 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
507 default:
508 llvm_unreachable("Unknown/unsupported reduction kind");
509 }
510}
511
512template <typename CommOp>
513struct CommOpPattern : public OpConversionPattern<CommOp> {
514 using OpConversionPattern<CommOp>::OpConversionPattern;
515
516 MemRefType getMemrefType(ShapedType tensorType) const {
517 return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
518 }
519
520 Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder,
521 bool readOnly) const {
522 auto itype = input.getType();
523 // If the source is a tensor, materialize a memref for it.
524 if (isa<RankedTensorType>(itype)) {
525 auto memrefType = getMemrefType(cast<ShapedType>(itype));
526 input = bufferization::ToBufferOp::create(iBuilder, memrefType, input,
527 readOnly);
528 } else {
529 assert(isa<MemRefType>(itype) &&
530 "expected input to be of MemRefType or TensorType");
531 }
532 return input;
533 }
534
535 FailureOr<GridOp> checkGrid(CommOp op,
536 SymbolTableCollection &symbolTableCollection,
537 bool allowDynamic = false) const {
538 GridOp gridOp = getGrid(op, symbolTableCollection);
539 if (!gridOp)
540 return op->emitError() << "Missing grid symbol.";
541 if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
542 return op->emitError() << "Dynamic grid shape not supported.";
543 return gridOp;
544 }
545
546 // Get an MPI_Comm_split for a given grid and axes.
547 // The color is the linear index of the process in the grid along the
548 // non-'grid-axes'. The key is the linear index of the process in the grid
549 // along the grid-axes.
550 Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
551 ImplicitLocOpBuilder &iBuilder) const {
552 size_t gridDims = gridOp.getShape().size();
553 auto commType = mpi::CommType::get(gridOp->getContext());
554 Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
555
556 if (gridAxes.empty() || gridAxes.size() >= gridDims) {
557 return commWorld;
558 }
559
560 SmallVector<GridAxis> otherAxes;
561 for (GridAxis i = 0; i < static_cast<GridAxis>(gridDims); ++i) {
562 if (!llvm::is_contained(gridAxes, i))
563 otherAxes.emplace_back(i);
564 }
565
566 SmallVector<Type> indexResultTypes(otherAxes.size(),
567 iBuilder.getIndexType());
568
569 Value color =
570 createProcessLinearIndex(iBuilder, gridOp.getSymName(), otherAxes);
571 color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
572
573 Value key =
574 createProcessLinearIndex(iBuilder, gridOp.getSymName(), gridAxes);
575 key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
576
577 // Finally split the communicator
578 return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
579 .getNewcomm();
580 }
581};
582
583struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
584 using CommOpPattern::CommOpPattern;
585
586 LogicalResult
587 matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
588 ConversionPatternRewriter &rewriter) const override {
589 SymbolTableCollection symbolTableCollection;
590 FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
591 if (failed(gridOp))
592 return failure();
593 ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
594 Value input = getAsMemref(adaptor.getInput(), iBuilder, true);
595 MemRefType inType = cast<MemRefType>(input.getType());
597 return op.emitError(
598 "Expected static shaped memref in contiguous row-major layout.");
599 MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
601 return op.emitError(
602 "Expected static shaped memref in contiguous row-major layout.");
603
604 // Allocate buffer and copy input to buffer.
605 Value buffer = memref::AllocOp::create(iBuilder, outType);
606 linalg::CopyOp::create(iBuilder, input, buffer);
607 // Get the right communicator
608 Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
609 // Create the MPI AllReduce operation.
610 mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer, buffer,
611 getMPIReductionOp(adaptor.getReductionAttr()),
612 comm);
613
614 // If the destination is a tensor, cast it to a tensor
615 if (isa<RankedTensorType>(op.getType()))
616 buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
617 true);
618 rewriter.replaceOp(op, buffer);
619 return success();
620 }
621};
622
623struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
624 using CommOpPattern::CommOpPattern;
625
626 // shard.reduce_scatter reduces and then scatters along a specified
627 // scatter-dim. mpi.reduce_scatter_block always scatters along the first
628 // dimension. Hence, if scatter-dim != 0, we need to rearrange the input
629 // data by expanding the scatter-dim into {nRanks, output_scatter_dim} and
630 // transposing nRanks to the first dimension.
631
632 LogicalResult
633 matchAndRewrite(ReduceScatterOp op, OpAdaptor adaptor,
634 ConversionPatternRewriter &rewriter) const override {
635 auto gridAxes = adaptor.getGridAxes();
636 int64_t scatterDim = adaptor.getScatterDimAttr().getInt();
637
638 SymbolTableCollection symbolTableCollection;
639 FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
640 if (failed(gridOp))
641 return failure();
642
643 ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
644 Value rawInput = adaptor.getInput();
645 auto inShapedType = cast<ShapedType>(rawInput.getType());
646 MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
647 auto elemType = outType.getElementType();
648 auto inputShape = inShapedType.getShape();
649 auto outputShape = outType.getShape();
650 int64_t inputDimOnAxis = inputShape[scatterDim];
651 int64_t outputDimOnAxis = outputShape[scatterDim];
652
653 for (size_t i = 0; i < outputShape.size(); ++i)
654 if (outputShape[i] != inputShape[i] &&
655 i != static_cast<size_t>(scatterDim))
656 return op.emitError(
657 "Result and input shapes must match along non-scatter axes.");
658 if (outputDimOnAxis == 0)
659 return op.emitError(
660 "Output size along the scatter axis must be non-zero.");
661 if (inputDimOnAxis % outputDimOnAxis != 0)
662 return op.emitError(
663 "Input size along the scatter axis must be an exact "
664 "multiple of the output size along the scatter axis.");
665
667 return op.emitError("Result must be a statically shaped memref in "
668 "contiguous row-major layout.");
669
670 int64_t nRanks = inputDimOnAxis / outputDimOnAxis;
671
672 // Verify that nRanks matches the number of devices along the grid axes.
673 int64_t gridGroupSize =
674 collectiveProcessGroupSize(gridAxes, gridOp->getShape());
675 if (nRanks != gridGroupSize)
676 return op.emitError()
677 << "Expected the scatter factor (" << nRanks
678 << ") to match the number of devices along grid_axes ("
679 << gridGroupSize << ").";
680
681 // Get the right communicator.
682 Value comm = getComm(*gridOp, gridAxes, ib);
683
684 Value mpiInput;
685 if (scatterDim == 0) {
686 // scatter_dim == 0 maps directly to MPI_Reduce_scatter_block.
687 // Input must be contiguous for MPI.
688 Value input = getAsMemref(rawInput, ib, true);
689 MemRefType inType = cast<MemRefType>(input.getType());
691 return op.emitError("Input must be a statically shaped memref in "
692 "contiguous row-major layout.");
693 mpiInput = input;
694 } else {
695 // For scatter_dim != 0 we rearrange the input so the scatter factor
696 // becomes the first dimension.
697 //
698 // 1. Get a tensor representation of the input (avoid memref->tensor
699 // round-trip if the input is already a tensor).
700 Value tensorInput = rawInput;
701 if (!isa<RankedTensorType>(rawInput.getType())) {
702 auto inTensorType = RankedTensorType::get(inputShape, elemType);
703 tensorInput =
704 bufferization::ToTensorOp::create(ib, inTensorType, rawInput, true);
705 }
706
707 // 2. Expand the scatter dim from {d0, ..., d_sd, ..., dN} to
708 // {d0, ..., nRanks, o_sd, ..., dN}.
709 SmallVector<int64_t> expandedShape;
710 SmallVector<ReassociationIndices> expandReassociation;
711 int64_t expandedIdx = 0;
712 for (int64_t i = 0; i < static_cast<int64_t>(inputShape.size()); ++i) {
713 if (i == scatterDim) {
714 expandedShape.push_back(nRanks);
715 expandedShape.push_back(outputDimOnAxis);
716 expandReassociation.push_back({expandedIdx, expandedIdx + 1});
717 expandedIdx += 2;
718 } else {
719 expandedShape.push_back(inputShape[i]);
720 expandReassociation.push_back({expandedIdx});
721 expandedIdx += 1;
722 }
723 }
724 auto expandedType = RankedTensorType::get(expandedShape, elemType);
725 tensorInput = tensor::ExpandShapeOp::create(ib, expandedType, tensorInput,
726 expandReassociation);
727
728 // 3. Transpose to move nRanks (at position scatterDim) to position 0:
729 // {d0, ..., nRanks, o_sd, ..., dN} -> {nRanks, d0, ..., o_sd, ..., dN}
730 SmallVector<int64_t> permutation, transposedShape;
731 permutation.emplace_back(scatterDim);
732 for (int64_t i = 0; i < scatterDim; ++i)
733 permutation.emplace_back(i);
734 for (int64_t i = scatterDim + 1; i < (int64_t)expandedShape.size(); ++i)
735 permutation.emplace_back(i);
736 for (auto p : permutation)
737 transposedShape.emplace_back(expandedShape[p]);
738
739 Value permOutput = tensor::EmptyOp::create(ib, transposedShape, elemType);
740 tensorInput =
741 linalg::TransposeOp::create(ib, tensorInput, permOutput, permutation)
742 ->getResult(0);
743
744 // 4. Materialize as contiguous memref for MPI by copying into a
745 // freshly allocated buffer.
746 auto mpiInType = MemRefType::get(transposedShape, elemType);
747 Value transposedBuf =
748 bufferization::ToBufferOp::create(ib, mpiInType, tensorInput, true);
749 mpiInput = memref::AllocOp::create(ib, mpiInType);
750 linalg::CopyOp::create(ib, transposedBuf, mpiInput);
751 }
752
753 // Allocate output buffer.
754 Value output = memref::AllocOp::create(ib, outType);
755 // Create the MPI ReduceScatter operation.
756 mpi::ReduceScatterBlockOp::create(
757 ib, TypeRange(), mpiInput, output,
758 getMPIReductionOp(adaptor.getReductionAttr()), comm);
759
760 // If the destination is a tensor, cast it to a tensor.
761 if (isa<RankedTensorType>(op.getType()))
762 output =
763 bufferization::ToTensorOp::create(ib, op.getType(), output, true);
764 else if (scatterDim != 0) // Deallocate the temporary input buffer
765 memref::DeallocOp::create(ib, mpiInput);
766 // Notice: If this is called from tensor-world, then we assume an extra pass
767 // will take care of deallocating the intermediate buffers.
768
769 rewriter.replaceOp(op, output);
770 return success();
771 }
772};
773
774struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
775 using CommOpPattern::CommOpPattern;
776
777 // shard.allgather concatenates along a specified gather-axis.
778 // mpi.allgather always concatenates along the first dimension and
779 // there is no MPI operation that allows gathering along an arbitrary axis.
780 // Hence, if gather-axis != 0, we need to permute the output buffer
781 // accordingly.
782
783 LogicalResult
784 matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
785 ConversionPatternRewriter &rewriter) const override {
786 SymbolTableCollection symbolTableCollection;
787 FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
788 if (failed(gridOp))
789 return failure();
790
791 ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
792 Value input = getAsMemref(adaptor.getInput(), ib, true);
793 MemRefType inType = cast<MemRefType>(input.getType());
794 MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
795 auto inputShape = inType.getShape();
796 auto outputShape = outType.getShape();
797 int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
798 int64_t inputDimOnAxis = inputShape[gatherAxis];
799 int64_t outputDimOnAxis = outputShape[gatherAxis];
800
801 for (size_t i = 0; i < outputShape.size(); ++i)
802 if (outputShape[i] != inputShape[i] && i != (size_t)gatherAxis)
803 return op.emitError(
804 "Result and input shapes must match along non-gather axes.");
805 if (inputDimOnAxis == 0)
806 return op.emitError("Input size along the gather axis must be non-zero.");
807 if (inputDimOnAxis == 1) {
808 assert(outputDimOnAxis == inputDimOnAxis);
809 rewriter.replaceOp(op, adaptor.getInput());
810 return success();
811 }
812 if (outputDimOnAxis % inputDimOnAxis != 0)
813 return op.emitError("Result size along the gather axis must be an exact "
814 "multiple of the input size along the gather axis.");
815
818 return op.emitError("Input/result must be statically shaped memrefs in "
819 "contiguous row-major layout.");
820
821 // Get the right communicator.
822 Value comm = getComm(*gridOp, adaptor.getGridAxes(), ib);
823 Value nRanksV =
824 mpi::CommSizeOp::create(ib, ib.getI32Type(), comm).getSize();
825 nRanksV = arith::IndexCastOp::create(ib, ib.getIndexType(), nRanksV);
826 int64_t nRanks = outputDimOnAxis / inputDimOnAxis;
827 Value nRanksC = arith::ConstantIndexOp::create(ib, nRanks);
828 Value notError =
829 arith::CmpIOp::create(ib, arith::CmpIPredicate::eq, nRanksV, nRanksC);
830 cf::AssertOp::create(ib, notError,
831 "Expected number of ranks in the communicator to "
832 "match the output size along the gather axis divided "
833 "by the input size along the gather axis.");
834
835 // mpi.allgather always concatenates along the first dimension, so
836 // get a output buffer of shape {nRanks, dim0, ...}.
837 SmallVector<int64_t> gatherShape;
838 gatherShape.emplace_back(nRanks);
839 gatherShape.append(inputShape.begin(), inputShape.end());
840 auto gatherType = MemRefType::get(gatherShape, outType.getElementType());
841 Value finalOutput = memref::AllocOp::create(ib, gatherType);
842 // Create the MPI AllGather operation.
843 mpi::AllGatherOp::create(ib, TypeRange(), input, finalOutput, comm);
844
845 if (gatherAxis == 0) {
846 // If gather axis == 0, simply collapse the first 2 dims from {nRanks,
847 // dim0, ...} to {nRanks*dim0, ...}.
848 SmallVector<ReassociationIndices> reassociation;
849 reassociation.push_back({0, 1});
850 int64_t numGatherDims = gatherShape.size();
851 for (int64_t i = 2; i < numGatherDims; ++i)
852 reassociation.push_back({i});
853 finalOutput = memref::CollapseShapeOp::create(ib, outType, finalOutput,
854 reassociation);
855
856 // If the op's result is a tensor, cast it to a tensor.
857 if (isa<RankedTensorType>(op.getType()))
858 finalOutput = bufferization::ToTensorOp::create(ib, op.getType(),
859 finalOutput, true);
860 } else {
861 // 1. Enter tensor-land.
862 auto inType =
863 RankedTensorType::get(gatherShape, outType.getElementType());
864 finalOutput =
865 bufferization::ToTensorOp::create(ib, inType, finalOutput, true);
866
867 // 2. Permute the output buffer from {nRanks, dim0, ..., gatherAxis, ...}
868 // to {dim0, ..., nRanks, dim1,...}.
869 SmallVector<int64_t> outShapePermuted, permutation;
870 for (int i = 1; i <= gatherAxis; ++i) {
871 outShapePermuted.emplace_back(gatherShape[i]);
872 permutation.emplace_back(i);
873 }
874 outShapePermuted.emplace_back(gatherShape[0]);
875 permutation.emplace_back(0);
876 for (size_t i = gatherAxis + 1; i < gatherShape.size(); ++i) {
877 outShapePermuted.emplace_back(gatherShape[i]);
878 permutation.emplace_back(i);
879 }
880 Value permOutput = tensor::EmptyOp::create(ib, outShapePermuted,
881 outType.getElementType());
882 finalOutput =
883 linalg::TransposeOp::create(ib, finalOutput, permOutput, permutation)
884 ->getResult(0);
885
886 // 3. Collapse the output buffer from {dim0, ..., nRanks, gatherAxis, ...}
887 // to {dim0, ..., nRanks*gatherAxis, ...}.
888 SmallVector<ReassociationIndices> reassociation;
889 for (int64_t i = 0; i < gatherAxis; ++i) {
890 reassociation.push_back({i});
891 }
892 reassociation.push_back({gatherAxis, gatherAxis + 1});
893 for (int64_t i = gatherAxis + 2; i < (int64_t)outShapePermuted.size();
894 ++i) {
895 reassociation.push_back({i});
896 }
897 auto outTType =
898 RankedTensorType::get(outputShape, outType.getElementType());
899 finalOutput = tensor::CollapseShapeOp::create(ib, outTType, finalOutput,
900 reassociation);
901
902 // 4. Cast back to memref if needed.
903 if (isa<MemRefType>(op.getType()))
904 finalOutput =
905 bufferization::ToBufferOp::create(ib, outType, finalOutput, false);
906 }
907
908 rewriter.replaceOp(op, finalOutput);
909 return success();
910 }
911};
912
913struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
914 using OpConversionPattern::OpConversionPattern;
915
916 LogicalResult
917 matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
918 ConversionPatternRewriter &rewriter) const override {
919
920 // The input/output memref is assumed to be in C memory order.
921 // Halos are exchanged as 2 blocks per dimension (one for each side: down
922 // and up). For each haloed dimension `d`, the exchanged blocks are
923 // expressed as multi-dimensional subviews. The subviews include potential
924 // halos of higher dimensions `dh > d`, no halos for the lower dimensions
925 // `dl < d` and for dimension `d` the currently exchanged halo only.
926 // By iterating form higher to lower dimensions this also updates the halos
927 // in the 'corners'.
928 // memref.subview is used to read and write the halo data from and to the
929 // local data. Because subviews and halos can have mixed dynamic and static
930 // shapes, OpFoldResults are used whenever possible.
931
932 auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
933 adaptor.getHaloSizes(), rewriter);
934 if (haloSizes.empty()) {
935 // no halos -> nothing to do
936 rewriter.replaceOp(op, adaptor.getDestination());
937 return success();
938 }
939
940 SymbolTableCollection symbolTableCollection;
941 Location loc = op.getLoc();
942
943 // convert a OpFoldResult into a Value
944 auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
945 if (auto value = dyn_cast<Value>(v))
946 return value;
947 return arith::ConstantOp::create(
948 rewriter, loc,
949 rewriter.getIndexAttr(
950 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
951 };
952
953 auto dest = adaptor.getDestination();
954 auto dstShape = cast<ShapedType>(dest.getType()).getShape();
955 Value array = dest;
956 if (isa<RankedTensorType>(array.getType())) {
957 // If the destination is a memref, we need to cast it to a tensor
958 auto mmemrefType = MemRefType::get(
959 dstShape, cast<ShapedType>(array.getType()).getElementType());
960 array =
961 bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array);
962 }
963 auto rank = cast<ShapedType>(array.getType()).getRank();
964 auto opSplitAxes = adaptor.getSplitAxes().getAxes();
965 auto grid = adaptor.getGrid();
966 auto gridOp = getGrid(op, symbolTableCollection);
967 // subviews need Index values
968 for (auto &sz : haloSizes) {
969 if (auto value = dyn_cast<Value>(sz))
970 sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
971 value)
972 .getResult();
973 }
974
975 // most of the offset/size/stride data is the same for all dims
976 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
977 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
978 SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
979 auto currHaloDim = -1; // halo sizes are provided for split dimensions only
980 // we need the actual shape to compute offsets and sizes
981 for (auto i = 0; i < rank; ++i) {
982 auto s = dstShape[i];
983 if (ShapedType::isDynamic(s))
984 shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult();
985 else
986 shape[i] = rewriter.getIndexAttr(s);
987
988 if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
989 ++currHaloDim;
990 // the offsets for lower dim sstarts after their down halo
991 offsets[i] = haloSizes[currHaloDim * 2];
992
993 // prepare shape and offsets of highest dim's halo exchange
994 Value _haloSz = arith::AddIOp::create(
995 rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
996 toValue(haloSizes[currHaloDim * 2 + 1]));
997 // the halo shape of lower dims exlude the halos
998 dimSizes[i] =
999 arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
1000 .getResult();
1001 } else {
1002 dimSizes[i] = shape[i];
1003 }
1004 }
1005
1006 auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
1007 auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
1008 auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
1009 auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
1010
1011 SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
1012 rewriter.getIndexType());
1013 auto myMultiIndex =
1014 ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
1015 .getResult();
1016 // traverse all split axes from high to low dim
1017 for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
1018 auto splitAxes = opSplitAxes[dim];
1019 if (splitAxes.empty())
1020 continue;
1021 assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
1022 // Get the linearized ids of the neighbors (down and up) for the
1023 // given split
1024 auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
1025 myMultiIndex, splitAxes)
1026 .getResults();
1027 // MPI operates on i32...
1028 Value neighbourIDs[2] = {
1029 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
1030 tmp[0]),
1031 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
1032 tmp[1])};
1033
1034 auto lowerRecvOffset = rewriter.getIndexAttr(0);
1035 auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
1036 auto upperRecvOffset =
1037 arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
1038 toValue(haloSizes[currHaloDim * 2 + 1]));
1039 auto upperSendOffset = arith::SubIOp::create(
1040 rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
1041
1042 Value commWorld = mpi::CommWorldOp::create(
1043 rewriter, loc, mpi::CommType::get(op->getContext()));
1044
1045 // Make sure we send/recv in a way that does not lead to a dead-lock.
1046 // The current approach is by far not optimal, this should be at least
1047 // be a red-black pattern or using MPI_sendrecv.
1048 // Also, buffers should be re-used.
1049 // Still using temporary contiguous buffers for MPI communication...
1050 // Still yielding a "serialized" communication pattern...
1051 auto genSendRecv = [&](bool upperHalo) {
1052 auto orgOffset = offsets[dim];
1053 dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
1054 : haloSizes[currHaloDim * 2];
1055 // Check if we need to send and/or receive
1056 // Processes on the grid borders have only one neighbor
1057 auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
1058 auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
1059 auto hasFrom = arith::CmpIOp::create(
1060 rewriter, loc, arith::CmpIPredicate::sge, from, zero);
1061 auto hasTo = arith::CmpIOp::create(rewriter, loc,
1062 arith::CmpIPredicate::sge, to, zero);
1063 auto buffer = memref::AllocOp::create(
1064 rewriter, loc, dimSizes,
1065 cast<ShapedType>(array.getType()).getElementType());
1066 // if has neighbor: copy halo data from array to buffer and send
1067 scf::IfOp::create(
1068 rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
1069 offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
1070 : OpFoldResult(upperSendOffset);
1071 auto subview = memref::SubViewOp::create(
1072 builder, loc, array, offsets, dimSizes, strides);
1073 memref::CopyOp::create(builder, loc, subview, buffer);
1074 mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
1075 commWorld);
1076 scf::YieldOp::create(builder, loc);
1077 });
1078 // if has neighbor: receive halo data into buffer and copy to array
1079 scf::IfOp::create(
1080 rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
1081 offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
1082 : OpFoldResult(lowerRecvOffset);
1083 mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
1084 commWorld);
1085 auto subview = memref::SubViewOp::create(
1086 builder, loc, array, offsets, dimSizes, strides);
1087 memref::CopyOp::create(builder, loc, buffer, subview);
1088 scf::YieldOp::create(builder, loc);
1089 });
1090 memref::DeallocOp::create(rewriter, loc, buffer);
1091 offsets[dim] = orgOffset;
1092 };
1093
1094 auto doSendRecv = [&](int upOrDown) {
1095 OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
1096 Value haloSz = dyn_cast<Value>(v);
1097 if (!haloSz)
1098 haloSz = arith::ConstantOp::create(
1099 rewriter, loc,
1100 rewriter.getI32IntegerAttr(
1101 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
1102 auto hasSize = arith::CmpIOp::create(
1103 rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
1104 scf::IfOp::create(rewriter, loc, hasSize,
1105 [&](OpBuilder &builder, Location loc) {
1106 genSendRecv(upOrDown > 0);
1107 scf::YieldOp::create(builder, loc);
1108 });
1109 };
1110
1111 doSendRecv(0);
1112 doSendRecv(1);
1113
1114 // the shape for lower dims include higher dims' halos
1115 dimSizes[dim] = shape[dim];
1116 // -> the offset for higher dims is always 0
1117 offsets[dim] = rewriter.getIndexAttr(0);
1118 // on to next halo
1119 --currHaloDim;
1120 }
1121
1122 if (isa<MemRefType>(op.getResult().getType())) {
1123 rewriter.replaceOp(op, array);
1124 } else {
1125 assert(isa<RankedTensorType>(op.getResult().getType()));
1126 rewriter.replaceOp(op, bufferization::ToTensorOp::create(
1127 rewriter, loc, op.getResult().getType(), array,
1128 /*restrict=*/true, /*writable=*/true));
1129 }
1130 return success();
1131 }
1132};
1133
1134struct ConvertShardToMPIPass
1135 : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
1136 using Base::Base;
1137
1138 /// Run the dialect converter on the module.
1139 void runOnOperation() override {
1140 auto *ctxt = &getContext();
1141 RewritePatternSet patterns(ctxt);
1142 ConversionTarget target(getContext());
1143
1144 // Define a type converter to convert shard::ShardingType,
1145 // mostly for use in return operations.
1146 TypeConverter typeConverter;
1147 typeConverter.addConversion([](Type type) { return type; });
1148
1149 // convert shard::ShardingType to a tuple of RankedTensorTypes
1150 typeConverter.addConversion(
1151 [](ShardingType type,
1152 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1153 auto i16 = IntegerType::get(type.getContext(), 16);
1154 auto i64 = IntegerType::get(type.getContext(), 64);
1155 std::array<int64_t, 2> shp = {ShapedType::kDynamic,
1156 ShapedType::kDynamic};
1157 results.emplace_back(RankedTensorType::get(shp, i16));
1158 results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
1159 results.emplace_back(RankedTensorType::get(shp, i64));
1160 return success();
1161 });
1162
1163 // To 'extract' components, a UnrealizedConversionCastOp is expected
1164 // to define the input
1165 typeConverter.addTargetMaterialization(
1166 [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
1167 Location loc) {
1168 // Expecting a single input.
1169 if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
1170 return SmallVector<Value>();
1171 auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
1172 // Expecting an UnrealizedConversionCastOp.
1173 if (!castOp)
1174 return SmallVector<Value>();
1175 // Fill a vector with elements of the tuple/castOp.
1176 SmallVector<Value> results;
1177 for (auto oprnd : castOp.getInputs()) {
1178 if (!isa<RankedTensorType>(oprnd.getType()))
1179 return SmallVector<Value>();
1180 results.emplace_back(oprnd);
1181 }
1182 return results;
1183 });
1184
1185 // No shard dialect should left after conversion...
1186 target.addIllegalDialect<shard::ShardDialect>();
1187 // ...except the global GridOp. GridShapeOp which will get folded later.
1188 target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
1189 // Allow all the stuff that our patterns will convert to
1190 target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
1191 arith::ArithDialect, tensor::TensorDialect,
1192 bufferization::BufferizationDialect,
1193 linalg::LinalgDialect, memref::MemRefDialect,
1194 affine::AffineDialect, cf::ControlFlowDialect>();
1195 // Make sure the function signature, calls etc. are legal
1196 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1197 return typeConverter.isSignatureLegal(op.getFunctionType());
1198 });
1199 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
1200 [&](Operation *op) { return typeConverter.isLegal(op); });
1201
1202 patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
1203 ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
1204 ConvertAllGatherOp, ConvertAllReduceOp, ConvertReduceScatterOp,
1205 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
1206 SymbolTableCollection stc;
1209
1210 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
1211 patterns, typeConverter);
1212 populateCallOpTypeConversionPattern(patterns, typeConverter);
1213 populateReturnOpTypeConversionPattern(patterns, typeConverter);
1214
1215 (void)applyPartialConversion(getOperation(), target, std::move(patterns));
1216
1217 // Folding patterns cannot be mixed with conversion patterns -> extra pass.
1218 patterns.clear();
1219 SymbolTableCollection symbolTableCollection;
1220 mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
1221 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
1222 }
1223};
1224
1225} // namespace
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
Definition ShardOps.cpp:214
IntegerType getI32Type()
Definition Builders.cpp:67
IndexType getIndexType()
Definition Builders.cpp:55
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext() const
Return the context this location is uniqued in.
Definition Location.h:86
This class helps build Operations.
Definition Builders.h:209
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Definition Simplify.cpp:164
void populateAllSliceOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateProcessMultiIndexOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
int16_t GridAxis
Definition ShardOps.h:27
TypedValue< IndexType > createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid, ArrayRef< GridAxis > gridAxes={})
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition ShardOps.cpp:77
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
Definition ShardOps.h:156
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Definition ShardOps.h:131
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...