#define DEBUG_TYPE "shard-to-mpi"

#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
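// Helper: materialize a mixed list of static integers and dynamic Values as a
// vector of Values of the requested type (i64 or index).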
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(arith::ConstantOp::create(b, loc, type, val));
    }
  }
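// Helper: decompose a linear index into a multi-dimensional index over the
// given dimensions, emitting arith.remsi/arith.divsi ops.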
  int n = dimensions.size();
  for (int i = n - 1; i >= 0; --i) {
    multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]);
    linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]);
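// Helper: flatten a multi-dimensional index into a linear index by summing
// index * stride over all dimensions.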
  for (int i = multiIndex.size() - 1; i >= 0; --i) {
    Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride);
    linearIndex = arith::AddIOp::create(b, loc, linearIndex, off);
    stride = arith::MulIOp::create(b, loc, stride, dimensions[i]);
  }
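// Replaces GetShardingOp with the sharding value produced by the defining
// ShardingOp.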
struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
    rewriter.replaceOp(op, shardingOp.getResult());
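// Lowers ShardingOp to a tuple of three tensors (split axes padded with -1,
// halo sizes, sharded-dims offsets) wrapped in an UnrealizedConversionCastOp.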
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // Build a (#split-dims x max #axes) tensor of split axes, filled with -1.
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
    resSplitAxes =
        linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
            .getResult(0);

    // Insert the actual split axes for each sharded dimension.
    std::array<int64_t, 2> strides = {1, 1};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      auto tensorType = RankedTensorType::get({size}, i16);
      auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
      resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resSplitAxes, empty, empty,
                                                   empty, offs, sizes, strides);
    }

    // Build the (#splits x 2) tensor of halo sizes.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    Value resHaloSizes =
        haloSizes.empty()
            ? tensor::EmptyOp::create(rewriter, loc,
                                      std::array<int64_t, 2>{0, 0}, i64)
                  .getResult()
            : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
                  .getResult();

    // Build the tensor of sharded-dims offsets (or an empty tensor if none).
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      resOffsets = tensor::EmptyOp::create(rewriter, loc,
                                           std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto gridOp = getGrid(op, symbolTableCollection);
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);

      resOffsets = tensor::EmptyOp::create(
          rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      Value zero = arith::ConstantOp::create(
          rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0);
      SmallVector<Value> offsets = getMixedAsValues(
          rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
          adaptor.getDynamicShardedDimsOffsets());
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
                                                   resOffsets, empty, empty,
                                                   empty, offs, sizes, strides);
      }
    }

    // Cast to the converted result types and pack everything into a tuple.
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();
    resSplitAxes =
        tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
    resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});
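// Lowers ProcessLinearIndexOp to the MPI rank of the current process
// (mpi::CommRankOp on the world communicator), cast to index.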
class ConvertProcessLinearIndexOp
    : public OpConversionPattern<ProcessLinearIndexOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
    auto rank = mpi::CommRankOp::create(
        TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
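// Lowers NeighborsLinearIndicesOp (single split axis only): computes the
// linear indices of the lower and upper neighbor along that axis, yielding -1
// at the grid border.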
struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto axes = adaptor.getSplitAxes();
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(op, symbolTableCollection);
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    SmallVector<Value> dims;
    llvm::transform(
        gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value atBorder =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
    auto down = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
                  .getResult();
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    atBorder = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
        arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
    auto up = scf::IfOp::create(
        rewriter, loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          scf::YieldOp::create(builder, loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
          scf::YieldOp::create(builder, loc,
                               multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
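// Lowers ShardShapeOp: computes the local shard size per dimension, either
// from explicit sharded-dims offsets or by evenly dividing the dimension by
// the number of shards and distributing the remainder.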
struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding)
      return op->emitError()
             << "Expected ShardingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // Collect the dynamic dims and the dynamic device indices.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    SmallVector<Value> shape =
    SmallVector<Value> multiIdx =
    SymbolTableCollection symbolTableCollection;
    auto gridOp = getGrid(sharding, symbolTableCollection);
    if (ShapedType::isDynamicShape(gridOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();

    // If sharded-dims offsets are provided, materialize them as a 1-d tensor.
    Value shardedDimsOffs;
    if (!sharding.getStaticShardedDimsOffsets().empty()) {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      shardedDimsOffs = tensor::FromElementsOp::create(
          rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
          tmp);
    }

    Value zero =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
    Value one =
        arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));

    for (auto [i, dim] : llvm::enumerate(shape)) {
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        Value posVal = arith::ConstantOp::create(rewriter, loc,
                                                 rewriter.getIndexAttr(pos));
        Value idx = multiIdx[axes[0]];
        if (shardedDimsOffs) {
          // Offsets given: the shard size is the difference between two
          // consecutive offsets for this device along this dimension.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
          Value off =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          idx = arith::AddIOp::create(rewriter, loc, idx, one);
          Value nextOff =
              tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
          Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
        } else {
          // No offsets: divide the dimension evenly and add 1 to the shards
          // that take a share of the remainder.
          Value numShardsVal = arith::ConstantOp::create(
              rewriter, loc, rewriter.getIndexAttr(numShards));
          Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
          Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
          sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
          auto cond = arith::CmpIOp::create(
              rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
          sz = arith::AddIOp::create(rewriter, loc, sz, odd);
        }
        pos += numShards + 1;
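// Maps a shard::ReductionKind to the corresponding MPI reduction-operation
// attribute.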
static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
  auto *ctx = kind.getContext();
    return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
  switch (kind.getValue()) {
  case ReductionKind::Sum:
  case ReductionKind::Product:
  case ReductionKind::Min:
  case ReductionKind::Max:
  case ReductionKind::BitwiseAnd:
  case ReductionKind::BitwiseOr:
  case ReductionKind::BitwiseXor:
  }
  llvm_unreachable("Unknown/unsupported reduction kind");
template <typename CommOp>
struct CommOpPattern : public OpConversionPattern<CommOp> {
  using OpConversionPattern<CommOp>::OpConversionPattern;

  MemRefType getMemrefType(ShapedType tensorType) const {
    return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  }

  Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const {
    Type itype = input.getType();
    if (isa<RankedTensorType>(itype)) {
      auto memrefType = getMemrefType(cast<ShapedType>(itype));
      input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
    } else {
      assert(isa<MemRefType>(itype) &&
             "expected input to be of MemRefType or TensorType");
    }
    return input;
  }

  FailureOr<GridOp> checkGrid(CommOp op,
                              SymbolTableCollection &symbolTableCollection,
                              bool allowDynamic = false) const {
    auto gridOp = getGrid(op, symbolTableCollection);
    if (!gridOp)
      return op->emitError() << "Missing grid symbol.";
    if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
      return op->emitError() << "Dynamic grid shape not supported.";
    return gridOp;
  }

  // Create the MPI communicator for the given grid axes.
  Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
                ImplicitLocOpBuilder &iBuilder) const {
    size_t gridDims = gridOp.getShape().size();
    auto commType = mpi::CommType::get(gridOp->getContext());
    Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
    if (gridAxes.empty() || gridAxes.size() >= gridDims) {
      return commWorld;
    }
    SmallVector<GridAxis> otherAxes;
    for (GridAxis i = 0; i < static_cast<GridAxis>(gridDims); ++i) {
      if (!llvm::is_contained(gridAxes, i))
        otherAxes.emplace_back(i);
    }
    SmallVector<Type> indexResultTypes(otherAxes.size(),
                                       iBuilder.getIndexType());
    color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
    key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
    return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
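// Lowers AllReduceOp to mpi::AllReduceOp, operating on a copy of the input
// buffer.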
struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
  using CommOpPattern::CommOpPattern;

  LogicalResult
  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SymbolTableCollection symbolTableCollection;
    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
    Value input = getAsMemref(adaptor.getInput(), iBuilder);
    MemRefType inType = cast<MemRefType>(input.getType());
    assert(isStaticShapeAndContiguousRowMajor(inType) &&
           "Expected static shaped memref in contiguous row-major layout.");
    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
    assert(isStaticShapeAndContiguousRowMajor(outType) &&
           "Expected static shaped memref in contiguous row-major layout.");

    // Run the reduction in place on a copy of the input.
    Value buffer = memref::AllocOp::create(iBuilder, outType);
    linalg::CopyOp::create(iBuilder, input, buffer);
    Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
    mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer, buffer,
                             getMPIReductionOp(adaptor.getReductionAttr()),
                             comm);

    if (isa<RankedTensorType>(op.getType()))
      buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
    rewriter.replaceOp(op, buffer);
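// Lowers AllGatherOp to mpi::AllGatherOp into a newly allocated output buffer.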
struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
  using CommOpPattern::CommOpPattern;

  LogicalResult
  matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SymbolTableCollection symbolTableCollection;
    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
    Value input = getAsMemref(adaptor.getInput(), iBuilder);
    MemRefType inType = cast<MemRefType>(input.getType());
    assert(isStaticShapeAndContiguousRowMajor(inType) &&
           "Expected static shaped memref in contiguous row-major layout.");
    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
    assert(isStaticShapeAndContiguousRowMajor(outType) &&
           "Expected static shaped memref in contiguous row-major layout.");

    Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
    Value output = memref::AllocOp::create(iBuilder, outType);
    mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);

    if (isa<RankedTensorType>(op.getType()))
      output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
    rewriter.replaceOp(op, output);
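// Lowers UpdateHaloOp to explicit halo exchanges: for each sharded dimension
// the lower/upper halo slices are copied to temporary buffers and exchanged
// with the neighboring processes via mpi::SendOp/mpi::RecvOp.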
struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> haloSizes = getMixedValues(
        adaptor.getStaticHaloSizes(), adaptor.getHaloSizes(), rewriter);
    if (haloSizes.empty()) {
      // No halos -> nothing to do.
      rewriter.replaceOp(op, adaptor.getDestination());
      return success();
    }

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();

    // Convert an OpFoldResult into a Value.
    auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
      if (auto value = dyn_cast<Value>(v))
        return value;
      return arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getIndexAttr(
              cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
    };

    auto dest = adaptor.getDestination();
    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
    Value array = dest;
    if (isa<RankedTensorType>(array.getType())) {
      // Operate on the underlying buffer of a tensor destination.
      auto mmemrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array);
    }
    auto rank = cast<ShapedType>(array.getType()).getRank();
    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
    auto grid = adaptor.getGrid();
    auto gridOp = getGrid(op, symbolTableCollection);

    // Cast dynamic halo sizes to index.
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz))
        sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
                                        value)
                 .getResult();
    }

    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
    auto currHaloDim = -1;
    for (auto i = 0; i < rank; ++i) {
      auto s = dstShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = memref::DimOp::create(rewriter, loc, array, i).getResult();
      else
        shape[i] = rewriter.getIndexAttr(s);
      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
        ++currHaloDim;
        offsets[i] = haloSizes[currHaloDim * 2];
        // Compute the core size of this dimension, i.e. without the halos.
        Value _haloSz = arith::AddIOp::create(
            rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
            toValue(haloSizes[currHaloDim * 2 + 1]));
        dimSizes[i] =
            arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
                .getResult();
      } else {
        dimSizes[i] = shape[i];
      }
    }
    auto tagAttr = rewriter.getI32IntegerAttr(91);
    auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
    auto zeroAttr = rewriter.getI32IntegerAttr(0);
    auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);

    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
                                       rewriter.getIndexType());
    auto myMultiIndex =
        ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)

    // Traverse the split dimensions from last to first.
    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
      auto splitAxes = opSplitAxes[dim];
      if (splitAxes.empty())
        continue;
      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);

      // Get the linear indices of the lower and upper neighbors.
      auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
                                                  myMultiIndex, splitAxes)
      Value neighbourIDs[2] = {
          arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
          arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),

      auto lowerRecvOffset = rewriter.getIndexAttr(0);
      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
      auto upperRecvOffset =
          arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
                                toValue(haloSizes[currHaloDim * 2 + 1]));
      auto upperSendOffset = arith::SubIOp::create(
          rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));

      Value commWorld = mpi::CommWorldOp::create(
          rewriter, loc, mpi::CommType::get(op->getContext()));

      // Send one side's halo slice to the neighbor and receive the matching
      // slice from the other neighbor, guarded by the neighbor's existence.
      auto genSendRecv = [&](bool upperHalo) {
        auto orgOffset = offsets[dim];
        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                  : haloSizes[currHaloDim * 2];
        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
        auto hasFrom = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::sge, from, zero);
        auto hasTo = arith::CmpIOp::create(rewriter, loc,
                                           arith::CmpIPredicate::sge, to, zero);
        auto buffer = memref::AllocOp::create(
            rewriter, loc, dimSizes,
            cast<ShapedType>(array.getType()).getElementType());
        scf::IfOp::create(
            rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
                                       : OpFoldResult(upperSendOffset);
              auto subview = memref::SubViewOp::create(
                  builder, loc, array, offsets, dimSizes, strides);
              memref::CopyOp::create(builder, loc, subview, buffer);
              mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to,
                                  commWorld);
              scf::YieldOp::create(builder, loc);
            });
        scf::IfOp::create(
            rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                       : OpFoldResult(lowerRecvOffset);
              mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from,
                                  commWorld);
              auto subview = memref::SubViewOp::create(
                  builder, loc, array, offsets, dimSizes, strides);
              memref::CopyOp::create(builder, loc, buffer, subview);
              scf::YieldOp::create(builder, loc);
            });
        memref::DeallocOp::create(rewriter, loc, buffer);
        offsets[dim] = orgOffset;
      };

      // Emit the exchange for one side of the halo if its size is non-zero.
      auto doSendRecv = [&](int upOrDown) {
        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
        Value haloSz = dyn_cast<Value>(v);
        if (!haloSz)
          haloSz = arith::ConstantOp::create(
              rewriter, loc,
              rewriter.getI32IntegerAttr(
                  cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
        auto hasSize = arith::CmpIOp::create(
            rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
        scf::IfOp::create(rewriter, loc, hasSize,
                          [&](OpBuilder &builder, Location loc) {
                            genSendRecv(upOrDown > 0);
                            scf::YieldOp::create(builder, loc);
                          });
      };

      dimSizes[dim] = shape[dim];
      offsets[dim] = rewriter.getIndexAttr(0);
    }

    if (isa<MemRefType>(op.getResult().getType())) {
      rewriter.replaceOp(op, array);
    } else {
      assert(isa<RankedTensorType>(op.getResult().getType()));
      rewriter.replaceOp(op, bufferization::ToTensorOp::create(
                                 rewriter, loc, op.getResult().getType(), array,
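// The conversion pass: registers the type conversions (expanding ShardingType
// into its three tensor components), marks the shard dialect illegal, and
// applies the patterns above via a partial conversion.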
struct ConvertShardToMPIPass
    : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {

  void runOnOperation() override {
    auto *ctxt = &getContext();
    RewritePatternSet patterns(ctxt);
    ConversionTarget target(getContext());

    // Every type stays legal by default.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });

    // A ShardingType expands into three tensors: split axes (i16), halo sizes
    // (i64) and sharded-dims offsets (i64).
    typeConverter.addConversion(
        [](ShardingType type,
           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          auto i16 = IntegerType::get(type.getContext(), 16);
          auto i64 = IntegerType::get(type.getContext(), 64);
          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
                                        ShapedType::kDynamic};
          results.emplace_back(RankedTensorType::get(shp, i16));
          results.emplace_back(RankedTensorType::get(shp, i64));
          results.emplace_back(RankedTensorType::get(shp, i64));
          return success();
        });

    // Materialize the expanded values from the tuple produced by
    // ConvertShardingOp.
    typeConverter.addTargetMaterialization(
          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
            return SmallVector<Value>();
          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
          if (!castOp)
            return SmallVector<Value>();

          SmallVector<Value> results;
          for (auto oprnd : castOp.getInputs()) {
            if (!isa<RankedTensorType>(oprnd.getType()))
              return SmallVector<Value>();
            results.emplace_back(oprnd);
          }
          return results;
        });

    target.addIllegalDialect<shard::ShardDialect>();
    target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
                           arith::ArithDialect, tensor::TensorDialect,
                           bufferization::BufferizationDialect,
                           linalg::LinalgDialect, memref::MemRefDialect,
                           affine::AffineDialect, cf::ControlFlowDialect>();

    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                 ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
                 ConvertAllGatherOp, ConvertAllReduceOp,
                 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
    SymbolTableCollection stc;
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);

    (void)applyPartialConversion(getOperation(), target, std::move(patterns));

    SymbolTableCollection symbolTableCollection;