42#define DEBUG_TYPE "shard-to-mpi"
45#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
46#include "mlir/Conversion/Passes.h.inc"
// Fragment of getMixedAsValues: merges a list of static int64 values and a
// list of dynamic Values into one vector of Values of the requested type.
// A ShapedType::kDynamic entry in `statics` is a placeholder for the next
// element of `dynamics`.
60 auto dyn = dynamics.begin();
61 Type i64 =
b.getI64Type();
// Only i64 and index results are supported.
// NOTE(review): "intex" in the assert message is a typo for "index".
64 assert((i64 == type ||
b.getIndexType() == type) &&
65 "expected an i64 or an intex type");
66 for (
auto s : statics) {
67 if (s == ShapedType::kDynamic) {
// Placeholder: consume the next dynamic operand.
68 values.emplace_back(*(dyn++));
// Static entry: materialize it as a constant of the requested type.
70 TypedAttr val = type == i64 ?
b.getI64IntegerAttr(s) :
b.getIndexAttr(s);
71 values.emplace_back(arith::ConstantOp::create(
b, loc, type, val));
// Fragment: decompose a linear index into a multi-dimensional index by
// iterating dimensions from innermost (last) to outermost, taking the
// remainder as the coordinate and the quotient as the remaining linear index.
81 int n = dimensions.size();
84 for (
int i = n - 1; i >= 0; --i) {
85 multiIndex[i] = arith::RemSIOp::create(
b, loc, linearIndex, dimensions[i]);
87 linearIndex = arith::DivSIOp::create(
b, loc, linearIndex, dimensions[i]);
// Fragment: fold a multi-dimensional index back into a linear index,
// accumulating a running stride from the innermost dimension outwards:
// linear += multiIndex[i] * stride; stride *= dimensions[i].
100 for (
int i = multiIndex.size() - 1; i >= 0; --i) {
101 Value off = arith::MulIOp::create(
b, loc, multiIndex[i], stride);
102 linearIndex = arith::AddIOp::create(
b, loc, linearIndex, off);
103 stride = arith::MulIOp::create(
b, loc, stride, dimensions[i]);
// Lowers shard::GetShardingOp by forwarding the result of the defining
// ShardingOp directly (the GetShardingOp becomes a no-op).
110struct ConvertGetShardingOp :
public OpConversionPattern<GetShardingOp> {
111 using OpConversionPattern::OpConversionPattern;
114 matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
115 ConversionPatternRewriter &rewriter)
const override {
// Requires the sharding operand to be produced by a ShardingOp.
119 auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
123 rewriter.replaceOp(op, shardingOp.getResult());
// Lowers shard::ShardingOp into three tensors:
//   1. split axes    : tensor<?x?xi16>, one row per sharded dimension,
//                      padded with -1 for unused slots,
//   2. halo sizes    : tensor<?x2xi64> (empty tensor if none),
//   3. dim offsets   : tensor<?x?xi64> of sharded-dims offsets (empty if none),
// bundled into a tuple through an UnrealizedConversionCastOp so later
// materializations can unpack them.
131struct ConvertShardingOp :
public OpConversionPattern<ShardingOp> {
132 using OpConversionPattern::OpConversionPattern;
135 matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter)
const override {
137 auto splitAxes = op.getSplitAxes().getAxes();
// Widest row needed to hold any dimension's split axes.
138 int64_t maxNAxes = 0;
139 for (
auto axes : splitAxes)
140 maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
145 Location loc = op.getLoc();
146 auto i16 = rewriter.getI16Type();
147 auto i64 = rewriter.getI64Type();
148 std::array<int64_t, 2> shape = {
static_cast<int64_t
>(splitAxes.size()),
// Initialize the split-axes tensor with -1 so unused slots are marked.
150 Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
151 auto attr = IntegerAttr::get(i16, -1);
152 Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
154 linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
158 std::array<int64_t, 2> strides = {1, 1};
// Write each dimension's split axes as one row via insert_slice.
161 for (
auto [i, axes] : llvm::enumerate(splitAxes)) {
162 int64_t size = axes.size();
165 std::array<int64_t, 2> offs = {(int64_t)i, 0};
166 std::array<int64_t, 2> sizes = {1, size};
167 auto tensorType = RankedTensorType::get({size}, i16);
169 auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
170 resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
171 resSplitAxes, empty, empty,
172 empty, offs, sizes, strides);
// Halo sizes: materialize static+dynamic sizes as Values; an empty tensor
// stands in when no halo sizes were provided.
177 SmallVector<Value> haloSizes =
179 adaptor.getDynamicHaloSizes());
180 auto type = RankedTensorType::get({nSplits, 2}, i64);
183 ? tensor::EmptyOp::create(rewriter, loc,
184 std::array<int64_t, 2>{0, 0}, i64)
186 : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
// Sharded-dims offsets: empty tensor when absent, otherwise a
// kDynamic-padded nSplits x maxSplitSize tensor filled row by row.
195 if (adaptor.getStaticShardedDimsOffsets().empty()) {
196 resOffsets = tensor::EmptyOp::create(rewriter, loc,
197 std::array<int64_t, 2>{0, 0}, i64);
199 SymbolTableCollection symbolTableCollection;
200 auto gridOp =
getGrid(op, symbolTableCollection);
// Determine the widest offsets row over all split-axis groups; the grid
// shape must be static here (asserted below).
201 int64_t maxSplitSize = 0;
202 for (
auto axes : splitAxes) {
205 assert(splitSize != ShapedType::kDynamic);
206 maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
208 assert(maxSplitSize);
211 resOffsets = tensor::EmptyOp::create(
212 rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
// Pre-fill with kDynamic as the "unused" marker value.
213 Value zero = arith::ConstantOp::create(
214 rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
216 linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0);
217 SmallVector<Value> offsets =
219 adaptor.getDynamicShardedDimsOffsets());
221 for (
auto [i, axes] : llvm::enumerate(splitAxes)) {
224 assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
226 ArrayRef<Value> values(&offsets[curr], splitSize);
227 Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
228 std::array<int64_t, 2> offs = {
static_cast<int64_t
>(i), 0};
229 std::array<int64_t, 2> sizes = {1, splitSize};
230 resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
231 resOffsets, empty, empty,
232 empty, offs, sizes, strides);
// Cast all three tensors to the converter-provided result types and
// bundle them in a tuple via UnrealizedConversionCastOp.
238 SmallVector<Type> resTypes;
239 if (
failed(getTypeConverter()->convertType(op.getResult().getType(),
244 tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
246 tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
247 resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);
249 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
250 op, TupleType::get(op.getContext(), resTypes),
251 ValueRange{resSplitAxes, resHaloSizes, resOffsets});
// Lowers shard::ProcessLinearIndexOp to mpi.comm_rank on MPI_COMM_WORLD,
// with the i32 rank cast to index type.
257class ConvertProcessLinearIndexOp
258 :
public OpConversionPattern<ProcessLinearIndexOp> {
261 using OpConversionPattern::OpConversionPattern;
264 matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter)
const override {
267 Location loc = op.getLoc();
270 mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
// comm_rank returns (retval, rank:i32); the rank is cast to index below.
271 auto rank = mpi::CommRankOp::create(
273 TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
276 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
// Lowers shard::NeighborsLinearIndicesOp: computes the linear indices of the
// "down" (index-1) and "up" (index+1) neighbors along a single split axis,
// yielding -1 when the process sits at the respective grid border.
282struct ConvertNeighborsLinearIndicesOp
283 :
public OpConversionPattern<NeighborsLinearIndicesOp> {
284 using OpConversionPattern::OpConversionPattern;
287 matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
294 auto axes = adaptor.getSplitAxes();
// Only single-axis neighbor queries are handled by this pattern.
296 if (axes.size() != 1)
299 Location loc = op.getLoc();
300 SymbolTableCollection symbolTableCollection;
301 auto gridOp =
getGrid(op, symbolTableCollection);
302 auto mIdx = adaptor.getDevice();
303 auto orgIdx = mIdx[axes[0]];
// Materialize the (static) grid shape as index constants.
304 SmallVector<Value> dims;
306 gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
307 return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
309 Value dimSz = dims[axes[0]];
// Lower neighbor: -1 at the lower border, otherwise linearize with the
// coordinate on the split axis decremented by one.
313 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
315 auto down = scf::IfOp::create(
316 rewriter, loc, atBorder,
317 [&](OpBuilder &builder, Location loc) {
318 scf::YieldOp::create(builder, loc, minus1);
320 [&](OpBuilder &builder, Location loc) {
321 SmallVector<Value> tmp = mIdx;
323 arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
325 scf::YieldOp::create(builder, loc,
326 multiToLinearIndex(loc, rewriter, tmp, dims));
// Upper neighbor: -1 at the upper border, otherwise linearize with the
// coordinate incremented by one.
328 atBorder = arith::CmpIOp::create(
329 rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
330 arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
331 auto up = scf::IfOp::create(
332 rewriter, loc, atBorder,
333 [&](OpBuilder &builder, Location loc) {
334 scf::YieldOp::create(builder, loc, minus1);
336 [&](OpBuilder &builder, Location loc) {
337 SmallVector<Value> tmp = mIdx;
339 arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
340 scf::YieldOp::create(builder, loc,
341 multiToLinearIndex(loc, rewriter, tmp, dims));
343 rewriter.replaceOp(op,
ValueRange{down.getResult(0), up.getResult(0)});
// Lowers shard::ShardShapeOp: computes the local shard extent for each
// dimension of the global shape, either from explicit sharded-dims offsets
// (difference of consecutive offsets) or by an even split with the remainder
// distributed to the highest-index shards.
348struct ConvertShardShapeOp :
public OpConversionPattern<ShardShapeOp> {
349 using OpConversionPattern::OpConversionPattern;
352 matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
353 ConversionPatternRewriter &rewriter)
const override {
// The sharding operand must come from a ShardingOp so its static
// attributes (split axes, offsets) can be inspected here.
354 auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
356 return op->emitError()
357 <<
"Expected ShardingOp as defining op for sharding"
358 <<
" but found " << adaptor.getSharding()[0].getDefiningOp();
370 Location loc = op.getLoc();
371 Type index = rewriter.getIndexType();
// Collect the single dynamic Value per dynamic dim/device operand.
377 SmallVector<Value> dynDims, dynDevice;
378 for (
auto dim : adaptor.getDimsDynamic()) {
380 dynDims.emplace_back(llvm::getSingleElement(dim));
383 for (
auto device : adaptor.getDeviceDynamic()) {
384 dynDevice.emplace_back(llvm::getSingleElement(device));
389 SmallVector<Value> shape =
391 SmallVector<Value> multiIdx =
395 SymbolTableCollection symbolTableCollection;
396 auto gridOp =
getGrid(sharding, symbolTableCollection);
// Dynamic grid shapes are not supported by this lowering.
398 if (ShapedType::isDynamicShape(gridOp.getShape()))
401 auto splitAxes = sharding.getSplitAxes().getAxes();
// If explicit sharded-dims offsets exist, pack them into a 1-D tensor so
// they can be indexed per shard below.
407 Value shardedDimsOffs;
410 rewriter, loc, sharding.getStaticShardedDimsOffsets(),
411 sharding.getDynamicShardedDimsOffsets(), index);
413 shardedDimsOffs = tensor::FromElementsOp::create(
414 rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
424 arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
426 arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));
// Per-dimension shard extent computation.
430 for (
auto [i, dim] : llvm::enumerate(shape)) {
432 if (i < splitAxes.size() && !splitAxes[i].empty()) {
433 auto axes = splitAxes[i];
436 Value posVal = arith::ConstantOp::create(rewriter, loc,
437 rewriter.getIndexAttr(pos));
439 Value idx = multiIdx[axes[0]];
442 if (shardedDimsOffs) {
// Offsets path: shard size = offs[idx+1] - offs[idx]; only one split
// axis per dimension is supported here.
445 if (axes.size() > 1) {
446 return op->emitError() <<
"Only single axis sharding is "
447 <<
"supported for each dimension.";
449 idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
452 tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
453 idx = arith::AddIOp::create(rewriter, loc, idx, one);
455 tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
456 Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
// Even-split path: base size = dim / numShards, with one extra element
// for shards whose index is >= numShards - (dim % numShards).
459 Value numShardsVal = arith::ConstantOp::create(
460 rewriter, loc, rewriter.getIndexAttr(numShards));
465 Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
466 Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
467 sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
468 auto cond = arith::CmpIOp::create(
469 rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
470 Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
471 sz = arith::AddIOp::create(rewriter, loc, sz, odd);
// Advance the running position into the flattened offsets tensor.
474 pos += numShards + 1;
// Maps a shard ReductionKindAttr to the corresponding MPI reduction-op
// enum attribute. Unknown kinds are unreachable.
486static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
487 auto *ctx = kind.getContext();
489 return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
492 switch (kind.getValue()) {
493 case ReductionKind::Sum:
495 case ReductionKind::Product:
497 case ReductionKind::Min:
499 case ReductionKind::Max:
501 case ReductionKind::BitwiseAnd:
503 case ReductionKind::BitwiseOr:
505 case ReductionKind::BitwiseXor:
508 llvm_unreachable(
"Unknown/unsupported reduction kind");
// Shared base for collective-op lowerings. Provides helpers to:
//  - derive a MemRefType from a tensor/shaped type,
//  - bufferize a tensor input to a memref,
//  - validate the grid symbol (optionally allowing dynamic shapes),
//  - build the MPI communicator for a set of grid axes (world comm or a
//    comm_split keyed by the process's position on the non-participating axes).
512template <
typename CommOp>
513struct CommOpPattern :
public OpConversionPattern<CommOp> {
514 using OpConversionPattern<CommOp>::OpConversionPattern;
// Memref type with the same shape/element type as the given shaped type.
516 MemRefType getMemrefType(ShapedType tensorType)
const {
517 return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
// Returns the input as a memref, inserting a to_buffer op for tensors.
520 Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder)
const {
523 if (isa<RankedTensorType>(itype)) {
524 auto memrefType = getMemrefType(cast<ShapedType>(itype));
525 input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
527 assert(isa<MemRefType>(itype) &&
528 "expected input to be of MemRefType or TensorType");
// Resolves and validates the grid: fails on a missing symbol, and on a
// dynamic grid shape unless allowDynamic is set.
533 FailureOr<GridOp> checkGrid(CommOp op,
534 SymbolTableCollection &symbolTableCollection,
535 bool allowDynamic =
false)
const {
538 return op->emitError() <<
"Missing grid symbol.";
539 if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
540 return op->emitError() <<
"Dynamic grid shape not supported.";
// Builds the communicator covering gridAxes: the world communicator when
// all axes participate, otherwise comm_split with color = linearized index
// over the non-participating axes.
548 Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
549 ImplicitLocOpBuilder &iBuilder)
const {
550 size_t gridDims = gridOp.getShape().size();
551 auto commType = mpi::CommType::get(gridOp->getContext());
552 Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
554 if (gridAxes.empty() || gridAxes.size() >= gridDims) {
// Axes NOT in gridAxes determine the split color.
558 SmallVector<GridAxis> otherAxes;
559 for (
GridAxis i = 0; i < static_cast<GridAxis>(gridDims); ++i) {
560 if (!llvm::is_contained(gridAxes, i))
561 otherAxes.emplace_back(i);
564 SmallVector<Type> indexResultTypes(otherAxes.size(),
// comm_split expects i32 color/key values.
569 color = arith::IndexCastOp::create(iBuilder, iBuilder.
getI32Type(), color);
573 key = arith::IndexCastOp::create(iBuilder, iBuilder.
getI32Type(), key);
576 return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
// Lowers shard::AllReduceOp: copies the input into a freshly allocated
// buffer, runs an in-place mpi.allreduce over the axes' communicator, and
// converts back to a tensor if the result type requires it.
581struct ConvertAllReduceOp :
public CommOpPattern<AllReduceOp> {
582 using CommOpPattern::CommOpPattern;
585 matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
586 ConversionPatternRewriter &rewriter)
const override {
587 SymbolTableCollection symbolTableCollection;
588 FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
591 ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
592 Value input = getAsMemref(adaptor.getInput(), iBuilder);
593 MemRefType inType = cast<MemRefType>(input.
getType());
// Both in and out memrefs must be static-shaped, contiguous row-major.
596 "Expected static shaped memref in contiguous row-major layout.");
597 MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
600 "Expected static shaped memref in contiguous row-major layout.");
// allreduce works in place, so stage the input into the result buffer.
603 Value buffer = memref::AllocOp::create(iBuilder, outType);
604 linalg::CopyOp::create(iBuilder, input, buffer);
606 Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
608 mpi::AllReduceOp::create(iBuilder,
TypeRange(), buffer, buffer,
609 getMPIReductionOp(adaptor.getReductionAttr()),
// Re-tensorize when the original op produced a tensor.
613 if (isa<RankedTensorType>(op.getType()))
614 buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
616 rewriter.replaceOp(op, buffer);
// Lowers shard::AllGatherOp to mpi.allgather. MPI gathers rank-contiguous
// blocks into a leading dimension of size nRanks; when the gather axis is
// not 0 the result must be transposed and the (nRanks, inputDim) pair
// collapsed back into the gather axis.
621struct ConvertAllGatherOp :
public CommOpPattern<AllGatherOp> {
622 using CommOpPattern::CommOpPattern;
631 matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
632 ConversionPatternRewriter &rewriter)
const override {
633 SymbolTableCollection symbolTableCollection;
634 FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
638 ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
639 Value input = getAsMemref(adaptor.getInput(), ib);
640 MemRefType inType = cast<MemRefType>(input.
getType());
641 MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
642 auto inputShape = inType.getShape();
643 auto outputShape = outType.getShape();
644 int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
645 int64_t inputDimOnAxis = inputShape[gatherAxis];
646 int64_t outputDimOnAxis = outputShape[gatherAxis];
// Shape validation: only the gather axis may differ between in and out.
648 for (
size_t i = 0; i < outputShape.size(); ++i)
649 if (outputShape[i] != inputShape[i] && i != (
size_t)gatherAxis)
651 "Result and input shapes must match along non-gather axes.");
652 if (inputDimOnAxis == 0)
653 return op.emitError(
"Input size along the gather axis must be non-zero.");
// Degenerate case: nothing to gather, forward the input unchanged.
654 if (inputDimOnAxis == 1) {
655 assert(outputDimOnAxis == inputDimOnAxis);
656 rewriter.replaceOp(op, adaptor.getInput());
659 if (outputDimOnAxis % inputDimOnAxis != 0)
660 return op.emitError(
"Result size along the gather axis must be an exact "
661 "multiple of the input size along the gather axis.");
665 return op.emitError(
"Input/result must be statically shaped memrefs in "
666 "contiguous row-major layout.");
669 Value comm = getComm(*gridOp, adaptor.getGridAxes(), ib);
// Runtime check that the communicator size matches the static rank count
// implied by the output/input ratio.
671 mpi::CommSizeOp::create(ib, ib.getI32Type(), comm).getSize();
672 nRanksV = arith::IndexCastOp::create(ib, ib.getIndexType(), nRanksV);
673 int64_t nRanks = outputDimOnAxis / inputDimOnAxis;
676 arith::CmpIOp::create(ib, arith::CmpIPredicate::eq, nRanksV, nRanksC);
677 cf::AssertOp::create(ib, notError,
678 "Expected number of ranks in the communicator to "
679 "match the output size along the gather axis divided "
680 "by the input size along the gather axis.");
// Gather into shape (nRanks, *inputShape): each rank's block lands in
// the leading dimension.
684 SmallVector<int64_t> gatherShape;
685 gatherShape.emplace_back(nRanks);
686 gatherShape.append(inputShape.begin(), inputShape.end());
687 auto gatherType = MemRefType::get(gatherShape, outType.getElementType());
688 Value finalOutput = memref::AllocOp::create(ib, gatherType);
690 mpi::AllGatherOp::create(ib,
TypeRange(), input, finalOutput, comm);
// gatherAxis == 0: just collapse (nRanks, dim0) into the output axis.
692 if (gatherAxis == 0) {
695 SmallVector<ReassociationIndices> reassociation;
696 reassociation.push_back({0, 1});
697 int64_t numGatherDims = gatherShape.size();
698 for (int64_t i = 2; i < numGatherDims; ++i)
699 reassociation.push_back({i});
700 finalOutput = memref::CollapseShapeOp::create(ib, outType, finalOutput,
704 if (isa<RankedTensorType>(op.getType()))
705 finalOutput = bufferization::ToTensorOp::create(ib, op.getType(),
// gatherAxis != 0: move the leading nRanks dimension next to the gather
// axis via a tensor transpose, then collapse the pair.
710 RankedTensorType::get(gatherShape, outType.getElementType());
712 bufferization::ToTensorOp::create(ib, inType, finalOutput,
true);
// Permutation: dims 1..gatherAxis first, then dim 0, then the rest.
716 SmallVector<int64_t> outShapePermuted, permutation;
717 for (
int i = 1; i <= gatherAxis; ++i) {
718 outShapePermuted.emplace_back(gatherShape[i]);
719 permutation.emplace_back(i);
721 outShapePermuted.emplace_back(gatherShape[0]);
722 permutation.emplace_back(0);
723 for (
size_t i = gatherAxis + 1; i < gatherShape.size(); ++i) {
724 outShapePermuted.emplace_back(gatherShape[i]);
725 permutation.emplace_back(i);
727 Value permOutput = tensor::EmptyOp::create(ib, outShapePermuted,
728 outType.getElementType());
730 linalg::TransposeOp::create(ib, finalOutput, permOutput, permutation)
// Collapse (gatherAxis, gatherAxis+1) back into one dimension.
735 SmallVector<ReassociationIndices> reassociation;
736 for (int64_t i = 0; i < gatherAxis; ++i) {
737 reassociation.push_back({i});
739 reassociation.push_back({gatherAxis, gatherAxis + 1});
740 for (int64_t i = gatherAxis + 2; i < (int64_t)outShapePermuted.size();
742 reassociation.push_back({i});
745 RankedTensorType::get(outputShape, outType.getElementType());
746 finalOutput = tensor::CollapseShapeOp::create(ib, outTType, finalOutput,
// Buffer the result again when a memref result is expected.
750 if (isa<MemRefType>(op.getType()))
752 bufferization::ToBufferOp::create(ib, outType, finalOutput);
755 rewriter.replaceOp(op, finalOutput);
// Lowers shard::UpdateHaloOp to explicit MPI send/recv pairs: for each split
// dimension (processed from last to first) it copies the boundary slices of
// the local array into temporary buffers, exchanges them with the lower and
// upper neighbors, and writes received data into the halo regions.
760struct ConvertUpdateHaloOp :
public OpConversionPattern<UpdateHaloOp> {
761 using OpConversionPattern::OpConversionPattern;
764 matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
765 ConversionPatternRewriter &rewriter)
const override {
780 adaptor.getHaloSizes(), rewriter);
// No halo sizes: nothing to exchange, forward the destination.
781 if (haloSizes.empty()) {
783 rewriter.replaceOp(op, adaptor.getDestination());
787 SymbolTableCollection symbolTableCollection;
788 Location loc = op.getLoc();
// Helper: turn an OpFoldResult into a Value, materializing attributes as
// index constants.
791 auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
792 if (
auto value = dyn_cast<Value>(v))
794 return arith::ConstantOp::create(
796 rewriter.getIndexAttr(
797 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
800 auto dest = adaptor.getDestination();
801 auto dstShape = cast<ShapedType>(dest.getType()).getShape();
// Work on a memref view; bufferize tensor destinations first.
803 if (isa<RankedTensorType>(array.
getType())) {
805 auto mmemrefType = MemRefType::get(
806 dstShape, cast<ShapedType>(array.
getType()).getElementType());
808 bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array);
810 auto rank = cast<ShapedType>(array.
getType()).getRank();
811 auto opSplitAxes = adaptor.getSplitAxes().getAxes();
812 auto grid = adaptor.getGrid();
813 auto gridOp =
getGrid(op, symbolTableCollection);
// Normalize dynamic halo sizes to index type.
815 for (
auto &sz : haloSizes) {
816 if (
auto value = dyn_cast<Value>(sz))
817 sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
// Per-dimension subview parameters; currHaloDim tracks which halo pair
// (lower/upper) belongs to the dimension being processed.
823 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
824 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
825 SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
826 auto currHaloDim = -1;
828 for (
auto i = 0; i < rank; ++i) {
829 auto s = dstShape[i];
830 if (ShapedType::isDynamic(s))
831 shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult();
833 shape[i] = rewriter.getIndexAttr(s);
// Split dimensions exclude their halos from the exchanged core region.
835 if ((
size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
838 offsets[i] = haloSizes[currHaloDim * 2];
// Core size = full size minus (lower + upper) halo.
841 Value _haloSz = arith::AddIOp::create(
842 rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
843 toValue(haloSizes[currHaloDim * 2 + 1]));
846 arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
849 dimSizes[i] = shape[i];
// Fixed message tag; zero is used for "no neighbor" comparisons below.
853 auto tagAttr = rewriter.getI32IntegerAttr(91);
854 auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
855 auto zeroAttr = rewriter.getI32IntegerAttr(0);
856 auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
858 SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
859 rewriter.getIndexType());
861 ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
// Iterate split dimensions from last to first so halo regions updated
// earlier can feed later exchanges.
864 for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
865 auto splitAxes = opSplitAxes[dim];
866 if (splitAxes.empty())
868 assert(currHaloDim >= 0 && (
size_t)currHaloDim < haloSizes.size() / 2);
// Neighbor ranks along this dimension (down/up), cast to i32 for MPI;
// a negative value means "no neighbor" (grid border).
871 auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
872 myMultiIndex, splitAxes)
875 Value neighbourIDs[2] = {
876 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
878 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
// Offsets of the send/recv slices at the lower and upper boundary.
881 auto lowerRecvOffset = rewriter.getIndexAttr(0);
882 auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
883 auto upperRecvOffset =
884 arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
885 toValue(haloSizes[currHaloDim * 2 + 1]));
886 auto upperSendOffset = arith::SubIOp::create(
887 rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
889 Value commWorld = mpi::CommWorldOp::create(
890 rewriter, loc, mpi::CommType::get(op->getContext()));
// Emits one send+recv pair for either the upper or lower halo: copy the
// boundary subview into a staging buffer, send it (if a neighbor exists),
// receive into the buffer and copy into the halo subview (if a neighbor
// exists), then free the buffer and restore the offset.
898 auto genSendRecv = [&](
bool upperHalo) {
899 auto orgOffset = offsets[dim];
900 dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
901 : haloSizes[currHaloDim * 2];
904 auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
905 auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
906 auto hasFrom = arith::CmpIOp::create(
907 rewriter, loc, arith::CmpIPredicate::sge, from, zero);
908 auto hasTo = arith::CmpIOp::create(rewriter, loc,
909 arith::CmpIPredicate::sge, to, zero);
910 auto buffer = memref::AllocOp::create(
911 rewriter, loc, dimSizes,
912 cast<ShapedType>(array.
getType()).getElementType());
// Guarded send: only when the destination neighbor exists.
915 rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
916 offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
917 : OpFoldResult(upperSendOffset);
918 auto subview = memref::SubViewOp::create(
919 builder, loc, array, offsets, dimSizes, strides);
920 memref::CopyOp::create(builder, loc, subview, buffer);
921 mpi::SendOp::create(builder, loc,
TypeRange{}, buffer, tag, to,
923 scf::YieldOp::create(builder, loc);
// Guarded recv: only when the source neighbor exists.
927 rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
928 offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
929 : OpFoldResult(lowerRecvOffset);
930 mpi::RecvOp::create(builder, loc,
TypeRange{}, buffer, tag, from,
932 auto subview = memref::SubViewOp::create(
933 builder, loc, array, offsets, dimSizes, strides);
934 memref::CopyOp::create(builder, loc, buffer, subview);
935 scf::YieldOp::create(builder, loc);
937 memref::DeallocOp::create(rewriter, loc, buffer);
938 offsets[dim] = orgOffset;
// Wraps genSendRecv in a runtime guard that skips zero-sized halos.
941 auto doSendRecv = [&](
int upOrDown) {
942 OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
943 Value haloSz = dyn_cast<Value>(v);
945 haloSz = arith::ConstantOp::create(
947 rewriter.getI32IntegerAttr(
948 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
949 auto hasSize = arith::CmpIOp::create(
950 rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
951 scf::IfOp::create(rewriter, loc, hasSize,
952 [&](OpBuilder &builder, Location loc) {
953 genSendRecv(upOrDown > 0);
954 scf::YieldOp::create(builder, loc);
// Restore the full extent/offset for this dim before the next iteration.
962 dimSizes[dim] = shape[dim];
964 offsets[dim] = rewriter.getIndexAttr(0);
// Return the updated array, re-tensorized if the op yields a tensor.
969 if (isa<MemRefType>(op.getResult().getType())) {
970 rewriter.replaceOp(op, array);
972 assert(isa<RankedTensorType>(op.getResult().getType()));
973 rewriter.replaceOp(op, bufferization::ToTensorOp::create(
974 rewriter, loc, op.getResult().getType(), array,
// The pass driver: sets up the type converter (ShardingType -> three ranked
// tensors, identity for everything else), materializations for tuple
// unpacking, the conversion target (shard dialect illegal except GridOp /
// GridShapeOp), registers all patterns above, and runs partial conversion.
981struct ConvertShardToMPIPass
982 :
public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
986 void runOnOperation()
override {
993 TypeConverter typeConverter;
// Identity conversion as the fallback for all types.
994 typeConverter.addConversion([](Type type) {
return type; });
// ShardingType expands to (split axes : i16 tensor, halo sizes : i64
// tensor, sharded dims offsets : i64 tensor), all dynamically shaped.
997 typeConverter.addConversion(
998 [](ShardingType type,
999 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1000 auto i16 = IntegerType::get(type.getContext(), 16);
1001 auto i64 = IntegerType::get(type.getContext(), 64);
1002 std::array<int64_t, 2> shp = {ShapedType::kDynamic,
1003 ShapedType::kDynamic};
1004 results.emplace_back(RankedTensorType::get(shp, i16));
1005 results.emplace_back(RankedTensorType::get(shp, i64));
1006 results.emplace_back(RankedTensorType::get(shp, i64));
// Target materialization: unpack the tuple-producing
// UnrealizedConversionCastOp (built by ConvertShardingOp) back into its
// three ranked-tensor operands; bail out on anything else.
1012 typeConverter.addTargetMaterialization(
1016 if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
1017 return SmallVector<Value>();
1018 auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
1021 return SmallVector<Value>();
1023 SmallVector<Value> results;
1024 for (
auto oprnd : castOp.getInputs()) {
1025 if (!isa<RankedTensorType>(oprnd.getType()))
1026 return SmallVector<Value>();
1027 results.emplace_back(oprnd);
// Legality: all shard ops must be converted except the grid declarations.
1033 target.addIllegalDialect<shard::ShardDialect>();
1035 target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
1037 target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
1038 arith::ArithDialect, tensor::TensorDialect,
1039 bufferization::BufferizationDialect,
1040 linalg::LinalgDialect, memref::MemRefDialect,
1041 affine::AffineDialect, cf::ControlFlowDialect>();
// Functions/calls/returns are legal only once their signatures use
// converted types.
1043 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1044 return typeConverter.isSignatureLegal(op.getFunctionType());
1046 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
1047 [&](Operation *op) {
return typeConverter.isLegal(op); });
1049 patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
1050 ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
1051 ConvertAllGatherOp, ConvertAllReduceOp,
1052 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
1053 SymbolTableCollection stc;
1057 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
// Partial conversion: remaining legal ops are left untouched.
1062 (void)applyPartialConversion(getOperation(),
target, std::move(
patterns));
1066 SymbolTableCollection symbolTableCollection;
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
This class helps build Operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateAllSliceOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateProcessMultiIndexOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
TypedValue< IndexType > createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid, ArrayRef< GridAxis > gridAxes={})
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
const FrozenRewritePatternSet & patterns