#define DEBUG_TYPE "shard-to-mpi"

#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
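// Helper: materialize a mixed list of static sizes and dynamic values as a
// uniform vector of SSA values. Static entries become arith.constant ops of
// the requested type (i64 or index); dynamic entries are consumed from
// `dynamics` in order.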
static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                           llvm::ArrayRef<int64_t> statics,
                                           ValueRange dynamics,
                                           Type type = Type()) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  if (!type)
    type = i64; // to be used for constant building only
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
    }
  }
  return values;
}
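// Decompose a linear process index into a multi-dimensional index over the
// given grid dimensions (row-major): peel off the innermost dimension with a
// remainder, then divide, moving towards the outermost dimension.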
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                             Value linearIndex,
                                             ValueRange dimensions) {
  int n = dimensions.size();
  SmallVector<Value> multiIndex(n);

  for (int i = n - 1; i >= 0; --i) {
    multiIndex[i] =
        arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
    linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
  }

  return multiIndex;
}
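// The inverse of the above: fold a multi-dimensional index back into a
// linear index by accumulating index * stride, growing the stride by each
// dimension's extent from innermost to outermost.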
static Value multiToLinearIndex(Location loc, OpBuilder b,
                                ValueRange multiIndex, ValueRange dimensions) {
  Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0);
  Value stride = arith::ConstantIndexOp::create(b, loc, 1);

  for (int i = multiIndex.size() - 1; i >= 0; --i) {
    Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
    linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
    stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
  }

  return linearIndex;
}
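// Rewrites shard.get_sharding to the result of the shard.sharding op that
// defines the sharding of its source, so the sharding itself can be
// converted in place by ConvertShardingOp below.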
struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
    if (!shardOp)
      return failure();
    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
    if (!shardingOp)
      return failure();

    rewriter.replaceOp(op, shardingOp.getResult());
    return success();
  }
};
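// Converts shard.sharding into a tuple of three tensors holding the split
// axes, the halo sizes, and the sharded-dims offsets, each padded to a
// rectangular shape so the components can be passed across function
// boundaries.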
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // Split axes go into a 2d i16 tensor, padded with -1 for unused entries.
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
    resSplitAxes =
        linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
            .getResult(0);

    // Explicitly write the split axes into the tensor, row by row.
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
      resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resSplitAxes, empty, empty,
                                                   empty, offs, sizes, strides);
    }

    // Halo sizes: one (lower, upper) pair per split dimension.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    Value resHaloSizes =
        haloSizes.empty()
            ? tensor::EmptyOp::create(rewriter, loc,
                                      std::array<int64_t, 2>{0, 0}, i64)
                  .getResult()
            : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
                  .getResult();

    // Sharded-dims offsets: one row per split dimension, padded to the
    // largest row with the kDynamic sentinel.
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      resOffsets = tensor::EmptyOp::create(rewriter, loc,
                                           std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto gridOp = getGrid(op, symbolTableCollection);
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = tensor::EmptyOp::create(
          rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      Value zero = arith::ConstantOp::create(
          rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0);
      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resOffsets, empty, empty,
                                                   empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // Cast to the types requested by the type converter and return a tuple.
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    resSplitAxes =
        tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
    resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};
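// Computes the multi-dimensional process index from the process's linear
// index and the (static) grid shape, optionally restricted to a subset of
// grid axes.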
struct ConvertProcessMultiIndexOp
    : public OpConversionPattern<ProcessMultiIndexOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();
    auto gridOp = getGrid(op, symbolTableCollection);
    // For now we only support static grid shapes.
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return failure();

    SmallVector<Value> dims;
    llvm::transform(
        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
        });
    Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);

    // If requested, extract the subset of axes.
    auto axes = adaptor.getAxes();
    if (!axes.empty()) {
      SmallVector<Value> subIndex;
      for (auto axis : axes) {
        subIndex.emplace_back(mIdx[axis]);
      }
      mIdx = std::move(subIndex);
    }

    rewriter.replaceOp(op, mIdx);
    return success();
  }
};
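// Lowers the process's linear index to an mpi.comm_rank on MPI_COMM_WORLD,
// cast back to index type.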
class ConvertProcessLinearIndexOp
    : public OpConversionPattern<ProcessLinearIndexOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    Value commWorld =
        mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
    auto rank =
        mpi::CommRankOp::create(
            rewriter, loc,
            TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
            commWorld)
            .getRank();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                    rank);
    return success();
  }
};
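// Computes the linear indices of the direct neighbors of this process along
// the given split axis. A neighbor index of -1 signals "no neighbor", i.e.
// the process sits at the grid border in that direction.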
struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto axes = adaptor.getSplitAxes();
    // For now only a single split axis is supported.
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(op, symbolTableCollection);
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    SmallVector<Value> dims;
    llvm::transform(
        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1);
    // Lower neighbor: -1 at the lower border, else the linear index of the
    // process one step down along the split axis.
    Value atBorder =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
                              arith::ConstantIndexOp::create(rewriter, loc, 0));
    auto down = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
                  .getResult();
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    // Upper neighbor: same at the upper border.
    atBorder = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
        arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
    auto up = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};
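// Computes the local shard shape of this process: global dimension sizes are
// divided by the number of shards along each split dimension; when explicit
// sharded-dims offsets are given, the local size is instead the difference
// between this process's offset and the next one.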
struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      return op->emitError()
             << "Expected ShardingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // Fold the dynamic dims and device values out of their one-to-N ranges.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    // The global shape and this process's multi-index as values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(sharding, symbolTableCollection);
    // For now we only support static grid shapes.
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();

    // If sharded-dims offsets are provided, put them into a tensor so that
    // per-shard sizes can be extracted below.
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = tensor::FromElementsOp::create(
            rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
            tmp);
    }

    Value zero =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
    Value one =
        arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape and compute the
    // sharded size of each.
    int64_t pos = 0; // position in the sharded-dims-offsets tensor
    SmallVector<Value> shardShape;
    for (auto [i, dim] : llvm::enumerate(shape)) {
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        Value posVal = arith::ConstantOp::create(rewriter, loc,
                                                 rewriter.getIndexAttr(pos));
        // The index of the local shard in the sharded dimension.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
        if (shardedDimsOffs) {
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
          // The local size is the difference of neighboring offsets.
          Value off =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          idx = arith::AddIOp::create(rewriter, loc, idx, one);
          Value nextOff =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = arith::ConstantOp::create(
              rewriter, loc, rewriter.getIndexAttr(numShards));
          // Split evenly; the remainder goes to the trailing shards:
          // sz = dim / numShards + (idx >= numShards - dim % numShards).
          Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
          Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
          sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
          auto cond = arith::CmpIOp::create(
              rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
          sz = arith::AddIOp::create(rewriter, loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        pos += numShards + 1; // add one for the total size
      } else {
        // Unsharded dimensions keep their global size.
        shardShape.emplace_back(dim);
      }
    }
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};
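// Maps a shard::ReductionKind to the corresponding MPI reduction operation.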
static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
  auto *ctx = kind.getContext();
  auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
    return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
  };

  switch (kind.getValue()) {
  case ReductionKind::Sum:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
  case ReductionKind::Product:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
  case ReductionKind::Min:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
  case ReductionKind::Max:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
  case ReductionKind::BitwiseAnd:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
  case ReductionKind::BitwiseOr:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
  case ReductionKind::BitwiseXor:
    return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
  default:
    llvm_unreachable("Unknown/unsupported reduction kind");
  }
}
struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SymbolTableCollection symbolTableCollection;
    auto grid = adaptor.getGrid();
    mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
    if (!gridOp)
      return op->emitError() << "No grid found for AllReduceOp";
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return op->emitError()
             << "Dynamic grid shape not supported in AllReduceOp";

    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
    Value input = adaptor.getInput();
    auto inputShape = cast<ShapedType>(input.getType()).getShape();

    // If the source is a tensor, convert it to a memref first.
    if (isa<RankedTensorType>(input.getType())) {
      auto memrefType = MemRefType::get(
          inputShape, cast<ShapedType>(input.getType()).getElementType());
      input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
    }
    MemRefType inType = cast<MemRefType>(input.getType());

    // Get the allocation shape, querying dynamic dims from the input.
    SmallVector<OpFoldResult> shape(inType.getRank());
    for (auto i = 0; i < inType.getRank(); ++i) {
      auto s = inputShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = memref::DimOp::create(iBuilder, input, i).getResult();
      else
        shape[i] = iBuilder.getIndexAttr(s);
    }

    // The reduction happens in place, so copy the input into a fresh buffer.
    Value buffer = memref::AllocOp::create(
        iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
    linalg::CopyOp::create(iBuilder, input, buffer);

    // Compute the multi-index of this process in the grid.
    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
                                       iBuilder.getIndexType());
    SmallVector<Value> myMultiIndex =
        ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
            .getResult();
    Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
    SmallVector<Value> multiKey(myMultiIndex.size(), zero);

    auto redAxes = adaptor.getGridAxes();
    for (auto axis : redAxes) {
      multiKey[axis] = myMultiIndex[axis];
      myMultiIndex[axis] = zero;
    }

    // Processes with identical coordinates along the non-reduction axes get
    // the same color and thus share a communicator; the key orders the
    // processes within each group.
    Value color =
        createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
    color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
    Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
    key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);

    // Split the communicator.
    auto commType = mpi::CommType::get(op->getContext());
    Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
    auto comm =
        mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
            .getNewcomm();

    // MPI expects a flat buffer; collapse to 1d if needed.
    Value buffer1d = buffer;
    if (inType.getRank() > 1) {
      ReassociationIndices reassociation(inType.getRank());
      std::iota(reassociation.begin(), reassociation.end(), 0);
      buffer1d = memref::CollapseShapeOp::create(
          iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
    }

    // In-place allreduce on the buffer.
    mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
                             getMPIReductionOp(adaptor.getReductionAttr()),
                             comm);

    // Convert the result back to a tensor if needed.
    if (isa<RankedTensorType>(op.getType()))
      buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
                                                 /*restrict=*/true);

    rewriter.replaceOp(op, buffer);
    return success();
  }
};
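// Lowers shard.update_halo to MPI send/recv pairs. For each split dimension
// (traversed from the innermost to the outermost) the lower and upper halo
// slices are copied into temporary buffers, exchanged with the neighbor
// ranks, and written back into the destination's halo regions.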
struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Halos are exchanged as 2 blocks per split dimension (one for each
    // side: lower and upper). Subviews and halos can have mixed dynamic and
    // static shapes, so OpFoldResults are used whenever possible.
    auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
                                    adaptor.getHaloSizes(), rewriter);
    if (haloSizes.empty()) {
      // No halos -> nothing to do.
      rewriter.replaceOp(op, adaptor.getDestination());
      return success();
    }

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();

    // Convert an OpFoldResult into a Value.
    auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
      if (auto value = dyn_cast<Value>(v))
        return value;
      return arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getIndexAttr(
              cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
    };

    auto dest = adaptor.getDestination();
    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
    Value array = dest;
    if (isa<RankedTensorType>(array.getType())) {
      // If the destination is a tensor, convert it to a memref.
      auto memrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          bufferization::ToBufferOp::create(rewriter, loc, memrefType, array);
    }
    auto rank = cast<ShapedType>(array.getType()).getRank();
    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
    auto grid = adaptor.getGrid();
    auto gridOp = getGrid(op, symbolTableCollection);
    // Cast dynamic halo sizes to index type.
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz))
        sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
                                        value)
                 .getResult();
    }

    // Most of the offset/size/stride data is the same for all dims.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
    auto currHaloDim = -1; // halo sizes are provided for split dims only
    for (auto i = 0; i < rank; ++i) {
      auto s = dstShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = memref::DimOp::create(rewriter, loc, array, i).getResult();
      else
        shape[i] = rewriter.getIndexAttr(s);

      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
        ++currHaloDim;
        // The offsets for lower dims start after their halo.
        offsets[i] = haloSizes[currHaloDim * 2];

        // The halo shape of lower dims excludes the halos.
        Value _haloSz = arith::AddIOp::create(
            rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
            toValue(haloSizes[currHaloDim * 2 + 1]));
        dimSizes[i] =
            arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
                .getResult();
      } else {
        dimSizes[i] = shape[i];
      }
    }

    auto tagAttr = rewriter.getI32IntegerAttr(91); // arbitrary tag
    auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v < 0
    auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);

    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
                                       rewriter.getIndexType());
    auto myMultiIndex =
        ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
            .getResult();
    // Traverse all split axes from the highest to the lowest dim.
    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
      auto splitAxes = opSplitAxes[dim];
      if (splitAxes.empty())
        continue;
      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
      // Get the linearized ids of the neighbors (down and up) for the
      // current dim.
      auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
                                                  myMultiIndex, splitAxes)
                     .getResults();
      // MPI operates on i32.
      Value neighbourIDs[2] = {
          arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
                                     tmp[0]),
          arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
                                     tmp[1])};

      auto lowerRecvOffset = rewriter.getIndexAttr(0);
      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
      auto upperRecvOffset =
          arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
                                toValue(haloSizes[currHaloDim * 2 + 1]));
      auto upperSendOffset = arith::SubIOp::create(
          rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));

      Value commWorld = mpi::CommWorldOp::create(
          rewriter, loc, mpi::CommType::get(op->getContext()));

      // Exchange one halo (lower or upper) of the current dim: copy the send
      // slice into a contiguous buffer, exchange it with the neighbor, and
      // copy the received data into the halo region.
      auto genSendRecv = [&](bool upperHalo) {
        auto orgOffset = offsets[dim];
        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                  : haloSizes[currHaloDim * 2];
        // Check if we need to send and/or receive; processes at the grid
        // borders have fewer neighbors (id < 0 means none).
        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
        auto hasFrom = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::sge, from, zero);
        auto hasTo = arith::CmpIOp::create(rewriter, loc,
                                           arith::CmpIPredicate::sge, to, zero);
        auto buffer = memref::AllocOp::create(
            rewriter, loc, dimSizes,
            cast<ShapedType>(array.getType()).getElementType());
        // If we have a neighbor: copy halo data from array to buffer and send.
        scf::IfOp::create(
            rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
                                       : OpFoldResult(upperSendOffset);
              auto subview = memref::SubViewOp::create(
                  builder, loc, array, offsets, dimSizes, strides);
              memref::CopyOp::create(builder, loc, subview, buffer);
              mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
                                  commWorld);
              scf::YieldOp::create(builder, loc);
            });
        // If we have a neighbor: receive halo data into buffer and copy to
        // array.
        scf::IfOp::create(
            rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                       : OpFoldResult(lowerRecvOffset);
              mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
                                  commWorld);
              auto subview = memref::SubViewOp::create(
                  builder, loc, array, offsets, dimSizes, strides);
              memref::CopyOp::create(builder, loc, buffer, subview);
              scf::YieldOp::create(builder, loc);
            });
        memref::DeallocOp::create(rewriter, loc, buffer);
        offsets[dim] = orgOffset;
      };

      // Skip the exchange if the halo size is 0.
      auto doSendRecv = [&](int upOrDown) {
        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
        Value haloSz = dyn_cast<Value>(v);
        if (!haloSz)
          haloSz = arith::ConstantOp::create(
              rewriter, loc,
              rewriter.getI32IntegerAttr(
                  cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
        auto hasSize = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
        scf::IfOp::create(rewriter, loc, hasSize,
                          [&](OpBuilder &builder, Location loc) {
                            genSendRecv(upOrDown > 0);
                            scf::YieldOp::create(builder, loc);
                          });
      };

      doSendRecv(0);
      doSendRecv(1);

      // The shape for lower dims includes higher dims' halos...
      dimSizes[dim] = shape[dim];
      // ...so the offset for higher dims is always 0.
      offsets[dim] = rewriter.getIndexAttr(0);
      // On to the next halo.
      --currHaloDim;
    }

    if (isa<MemRefType>(op.getResult().getType())) {
      rewriter.replaceOp(op, array);
    } else {
      assert(isa<RankedTensorType>(op.getResult().getType()));
      rewriter.replaceOp(op, bufferization::ToTensorOp::create(
                                 rewriter, loc, op.getResult().getType(), array,
                                 /*restrict=*/true, /*writable=*/true));
    }
    return success();
  }
};
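// The pass: type-converts shard.sharding values into a tuple of three
// tensors and applies the conversion patterns above until no shard ops (other
// than the global grid declarations) remain.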
struct ConvertShardToMPIPass
    : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
  using Base::Base;

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    auto *ctxt = &getContext();
    RewritePatternSet patterns(ctxt);
    ConversionTarget target(getContext());

    // Define a type converter to convert shard::ShardingType, mostly for use
    // in return operations.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });

    // Convert a shard::ShardingType into a tuple of 3 tensors:
    // (split axes, halo sizes, sharded dims offsets).
    typeConverter.addConversion(
        [](ShardingType type,
           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          auto i16 = IntegerType::get(type.getContext(), 16);
          auto i64 = IntegerType::get(type.getContext(), 64);
          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
                                        ShapedType::kDynamic};
          results.emplace_back(RankedTensorType::get(shp, i16));
          results.emplace_back(RankedTensorType::get(shp, i64));
          results.emplace_back(RankedTensorType::get(shp, i64));
          return success();
        });

    // To 'extract' the components, an UnrealizedConversionCastOp is expected
    // to define the input.
    typeConverter.addTargetMaterialization(
        [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
           Location loc) {
          // Expecting a single input.
          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
            return SmallVector<Value>();
          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
          // Expecting an UnrealizedConversionCastOp.
          if (!castOp)
            return SmallVector<Value>();
          // Fill a vector with the elements of the tuple/castOp.
          SmallVector<Value> results;
          for (auto oprnd : castOp.getInputs()) {
            if (!isa<RankedTensorType>(oprnd.getType()))
              return SmallVector<Value>();
            results.emplace_back(oprnd);
          }
          return results;
        });

    // No shard dialect should be left after conversion...
    target.addIllegalDialect<shard::ShardDialect>();
    // ...except the global grid declarations.
    target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
    // Allow everything our patterns convert to.
    target.addLegalDialect<
        BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
        tensor::TensorDialect, bufferization::BufferizationDialect,
        linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
                 ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
                 ConvertProcessLinearIndexOp>(typeConverter, ctxt);

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    populateCallOpTypeConversionPattern(patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);

    (void)applyPartialConversion(getOperation(), target, std::move(patterns));

    // Folding patterns cannot be mixed with conversion patterns; apply them
    // in a separate greedy rewrite.
    patterns.clear();
    SymbolTableCollection symbolTableCollection;
    populateFoldingPatterns(patterns, symbolTableCollection);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};