42#define DEBUG_TYPE "shard-to-mpi"
45#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
46#include "mlir/Conversion/Passes.h.inc"
// NOTE(review): fragment of getMixedAsValues — converts a mixed
// static/dynamic size list into a flat list of SSA Values of the requested
// element type (i64 or index). This view is an extraction with missing
// surrounding lines; only the visible tokens are documented here.
// Fix: the assert message said "intex" — corrected to "index".
60 auto dyn = dynamics.begin();
61 Type i64 =
b.getI64Type();
// Only i64 and index are supported as the target type.
64 assert((i64 == type ||
b.getIndexType() == type) &&
65 "expected an i64 or an index type");
66 for (
auto s : statics) {
67 if (s == ShapedType::kDynamic) {
// Dynamic entries consume the next SSA value from `dynamics`.
68 values.emplace_back(*(dyn++));
// Static entries are materialized as constants of the requested type.
70 TypedAttr val = type == i64 ?
b.getI64IntegerAttr(s) :
b.getIndexAttr(s);
71 values.emplace_back(arith::ConstantOp::create(
b, loc, type, val));
// NOTE(review): fragment of a linear-to-multi-dimensional index conversion.
// Peels off the last dimension first via RemSI/DivSI, i.e. the last
// dimension varies fastest (row-major order) — confirm against the full file.
81 int n = dimensions.size();
84 for (
int i = n - 1; i >= 0; --i) {
// multiIndex[i] = linearIndex % dimensions[i]
85 multiIndex[i] = arith::RemSIOp::create(
b, loc, linearIndex, dimensions[i]);
// linearIndex /= dimensions[i]
87 linearIndex = arith::DivSIOp::create(
b, loc, linearIndex, dimensions[i]);
// NOTE(review): fragment of a multi-dimensional-to-linear index conversion.
// Accumulates from the innermost (last) dimension outward, maintaining a
// running stride: linear += multiIndex[i] * stride; stride *= dimensions[i].
100 for (
int i = multiIndex.size() - 1; i >= 0; --i) {
101 Value off = arith::MulIOp::create(
b, loc, multiIndex[i], stride);
102 linearIndex = arith::AddIOp::create(
b, loc, linearIndex, off);
103 stride = arith::MulIOp::create(
b, loc, stride, dimensions[i]);
// Conversion pattern for shard::GetShardingOp: forwards the result of the
// defining ShardingOp directly. NOTE(review): this extraction is missing
// interior lines (e.g. the failure path when the defining op is not a
// ShardingOp) — consult the full file before editing.
110struct ConvertGetShardingOp :
public OpConversionPattern<GetShardingOp> {
111 using OpConversionPattern::OpConversionPattern;
114 matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
115 ConversionPatternRewriter &rewriter)
const override {
// The sharding must trace back to a ShardingOp to be replaceable here.
119 auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
123 rewriter.replaceOp(op, shardingOp.getResult());
// Conversion pattern for shard::ShardingOp: lowers the sharding attribute-ish
// op into three tensors — split axes (i16), halo sizes (i64), and sharded
// dims offsets (i64) — packed into a tuple via UnrealizedConversionCastOp.
// NOTE(review): extraction with missing lines; several referenced names
// (nSplits, attrs, empty, curr, resHaloSizes) are declared on lines not
// visible here.
131struct ConvertShardingOp :
public OpConversionPattern<ShardingOp> {
132 using OpConversionPattern::OpConversionPattern;
135 matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter)
const override {
137 auto splitAxes = op.getSplitAxes().getAxes();
// Widest row needed in the split-axes tensor across all dimensions.
138 int64_t maxNAxes = 0;
139 for (
auto axes : splitAxes)
140 maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
145 Location loc = op.getLoc();
146 auto i16 = rewriter.getI16Type();
147 auto i64 = rewriter.getI64Type();
148 std::array<int64_t, 2> shape = {
static_cast<int64_t
>(splitAxes.size()),
// Split-axes tensor is initialized to -1 (meaning "no axis").
150 Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16);
151 auto attr = IntegerAttr::get(i16, -1);
152 Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr);
154 linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes)
158 std::array<int64_t, 2> strides = {1, 1};
// Insert each dimension's split axes as one row of the result tensor.
161 for (
auto [i, axes] : llvm::enumerate(splitAxes)) {
162 int64_t size = axes.size();
165 std::array<int64_t, 2> offs = {(int64_t)i, 0};
166 std::array<int64_t, 2> sizes = {1, size};
167 auto tensorType = RankedTensorType::get({size}, i16);
169 auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs);
170 resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals,
171 resSplitAxes, empty, empty,
172 empty, offs, sizes, strides);
// Halo sizes: empty 0x0 tensor when absent, otherwise from_elements.
177 SmallVector<Value> haloSizes =
179 adaptor.getDynamicHaloSizes());
180 auto type = RankedTensorType::get({nSplits, 2}, i64);
183 ? tensor::EmptyOp::create(rewriter, loc,
184 std::array<int64_t, 2>{0, 0}, i64)
186 : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
// Sharded-dims offsets: empty when not provided ...
195 if (adaptor.getStaticShardedDimsOffsets().empty()) {
196 resOffsets = tensor::EmptyOp::create(rewriter, loc,
197 std::array<int64_t, 2>{0, 0}, i64);
199 SymbolTableCollection symbolTableCollection;
200 auto gridOp =
getGrid(op, symbolTableCollection);
// ... otherwise a {nSplits, maxSplitSize} tensor padded with kDynamic.
201 int64_t maxSplitSize = 0;
202 for (
auto axes : splitAxes) {
205 assert(splitSize != ShapedType::kDynamic);
206 maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
208 assert(maxSplitSize);
211 resOffsets = tensor::EmptyOp::create(
212 rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
// Padding value is ShapedType::kDynamic, marking unused trailing slots.
213 Value zero = arith::ConstantOp::create(
214 rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
216 linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0);
217 SmallVector<Value> offsets =
219 adaptor.getDynamicShardedDimsOffsets());
221 for (
auto [i, axes] : llvm::enumerate(splitAxes)) {
224 assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
226 ArrayRef<Value> values(&offsets[curr], splitSize);
227 Value vals = tensor::FromElementsOp::create(rewriter, loc, values);
228 std::array<int64_t, 2> offs = {
static_cast<int64_t
>(i), 0};
229 std::array<int64_t, 2> sizes = {1, splitSize};
230 resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals,
231 resOffsets, empty, empty,
232 empty, offs, sizes, strides);
// Cast each result to the converted (dynamic-shaped) type and pack all
// three tensors into a tuple via an unrealized conversion cast.
238 SmallVector<Type> resTypes;
239 if (
failed(getTypeConverter()->convertType(op.getResult().getType(),
244 tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes);
246 tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes);
247 resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets);
249 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
250 op, TupleType::get(op.getContext(), resTypes),
251 ValueRange{resSplitAxes, resHaloSizes, resOffsets});
// Conversion pattern for shard::ProcessLinearIndexOp: the process's linear
// index is obtained from MPI as the rank in MPI_COMM_WORLD, then cast from
// i32 to index. NOTE(review): extraction with missing lines (e.g. the `ctx`
// declaration and the CommRankOp operand list are partly out of view).
257class ConvertProcessLinearIndexOp
258 :
public OpConversionPattern<ProcessLinearIndexOp> {
261 using OpConversionPattern::OpConversionPattern;
264 matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter)
const override {
267 Location loc = op.getLoc();
// Query the rank on the world communicator.
270 mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
271 auto rank = mpi::CommRankOp::create(
273 TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
// MPI ranks are i32; the shard dialect expects an index value.
276 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
// Conversion pattern for shard::NeighborsLinearIndicesOp: computes the linear
// indices of the lower ("down") and upper ("up") neighbors along a single
// split axis. At the grid border the neighbor index is -1 (see `minus1`,
// declared on a line not visible in this extraction). Only a single split
// axis is supported (early bail-out when axes.size() != 1).
282struct ConvertNeighborsLinearIndicesOp
283 :
public OpConversionPattern<NeighborsLinearIndicesOp> {
284 using OpConversionPattern::OpConversionPattern;
287 matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
294 auto axes = adaptor.getSplitAxes();
296 if (axes.size() != 1)
299 Location loc = op.getLoc();
300 SymbolTableCollection symbolTableCollection;
301 auto gridOp =
getGrid(op, symbolTableCollection);
302 auto mIdx = adaptor.getDevice();
303 auto orgIdx = mIdx[axes[0]];
// Materialize the (static) grid shape as index constants.
304 SmallVector<Value> dims;
306 gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
307 return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
309 Value dimSz = dims[axes[0]];
// Lower neighbor: -1 at the lower border, else linearize (idx - 1).
313 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx,
315 auto down = scf::IfOp::create(
316 rewriter, loc, atBorder,
317 [&](OpBuilder &builder, Location loc) {
318 scf::YieldOp::create(builder, loc, minus1);
320 [&](OpBuilder &builder, Location loc) {
321 SmallVector<Value> tmp = mIdx;
323 arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one)
325 scf::YieldOp::create(builder, loc,
326 multiToLinearIndex(loc, rewriter, tmp, dims));
// Upper neighbor: -1 at the upper border, else linearize (idx + 1).
328 atBorder = arith::CmpIOp::create(
329 rewriter, loc, arith::CmpIPredicate::sge, orgIdx,
330 arith::SubIOp::create(rewriter, loc, dimSz, one).getResult());
331 auto up = scf::IfOp::create(
332 rewriter, loc, atBorder,
333 [&](OpBuilder &builder, Location loc) {
334 scf::YieldOp::create(builder, loc, minus1);
336 [&](OpBuilder &builder, Location loc) {
337 SmallVector<Value> tmp = mIdx;
339 arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one);
340 scf::YieldOp::create(builder, loc,
341 multiToLinearIndex(loc, rewriter, tmp, dims));
343 rewriter.replaceOp(op,
ValueRange{down.getResult(0), up.getResult(0)});
// Conversion pattern for shard::ShardShapeOp: computes the local (per-shard)
// shape of a tensor from the global dims, the sharding, and the device's
// multi-index. Two paths per split dimension: explicit sharded-dims offsets
// (sizes are offset differences) or even distribution with the remainder
// spread over trailing shards. NOTE(review): extraction with missing lines;
// names like `tmp`, `pos`, `numShards`, and the result handling are partly
// out of view — consult the full file before editing.
348struct ConvertShardShapeOp :
public OpConversionPattern<ShardShapeOp> {
349 using OpConversionPattern::OpConversionPattern;
352 matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
353 ConversionPatternRewriter &rewriter)
const override {
// The sharding operand must come from a ShardingOp so its static info
// (split axes, offsets) can be read directly.
354 auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
356 return op->emitError()
357 <<
"Expected ShardingOp as defining op for sharding"
358 <<
" but found " << adaptor.getSharding()[0].getDefiningOp();
370 Location loc = op.getLoc();
371 Type index = rewriter.getIndexType();
// 1:N adaptor wraps each dynamic operand in a range; unwrap the single
// element of each.
377 SmallVector<Value> dynDims, dynDevice;
378 for (
auto dim : adaptor.getDimsDynamic()) {
380 dynDims.emplace_back(llvm::getSingleElement(dim));
383 for (
auto device : adaptor.getDeviceDynamic()) {
384 dynDevice.emplace_back(llvm::getSingleElement(device));
389 SmallVector<Value> shape =
391 SmallVector<Value> multiIdx =
395 SymbolTableCollection symbolTableCollection;
396 auto gridOp =
getGrid(sharding, symbolTableCollection);
// Dynamic grid shapes are not supported by this lowering.
398 if (ShapedType::isDynamicShape(gridOp.getShape()))
401 auto splitAxes = sharding.getSplitAxes().getAxes();
// Optional tensor of per-dimension shard offsets.
407 Value shardedDimsOffs;
410 rewriter, loc, sharding.getStaticShardedDimsOffsets(),
411 sharding.getDynamicShardedDimsOffsets(), index);
413 shardedDimsOffs = tensor::FromElementsOp::create(
414 rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
424 arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index));
426 arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index));
// Compute the local size of each dimension.
430 for (
auto [i, dim] : llvm::enumerate(shape)) {
432 if (i < splitAxes.size() && !splitAxes[i].empty()) {
433 auto axes = splitAxes[i];
436 Value posVal = arith::ConstantOp::create(rewriter, loc,
437 rewriter.getIndexAttr(pos));
439 Value idx = multiIdx[axes[0]];
442 if (shardedDimsOffs) {
// Explicit offsets: local size = offs[idx+1] - offs[idx].
445 if (axes.size() > 1) {
446 return op->emitError() <<
"Only single axis sharding is "
447 <<
"supported for each dimension.";
449 idx = arith::AddIOp::create(rewriter, loc, posVal, idx);
452 tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
453 idx = arith::AddIOp::create(rewriter, loc, idx, one);
455 tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx);
456 Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off);
// Even distribution: base size dim/numShards, plus 1 for the shards
// that absorb the remainder (idx >= numShards - rem).
459 Value numShardsVal = arith::ConstantOp::create(
460 rewriter, loc, rewriter.getIndexAttr(numShards));
465 Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal);
466 Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal);
467 sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1);
468 auto cond = arith::CmpIOp::create(
469 rewriter, loc, arith::CmpIPredicate::sge, idx, sz1);
470 Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero);
471 sz = arith::AddIOp::create(rewriter, loc, sz, odd);
474 pos += numShards + 1;
// Maps a shard::ReductionKind to the corresponding MPI reduction-op enum
// attribute. NOTE(review): extraction with missing lines — the per-case
// return values (e.g. MPI_SUM, MPI_PROD, ...) are not visible here.
486static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
487 auto *ctx = kind.getContext();
489 return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp);
492 switch (kind.getValue()) {
493 case ReductionKind::Sum:
495 case ReductionKind::Product:
497 case ReductionKind::Min:
499 case ReductionKind::Max:
501 case ReductionKind::BitwiseAnd:
503 case ReductionKind::BitwiseOr:
505 case ReductionKind::BitwiseXor:
// Any kind not handled above is a programming error.
508 llvm_unreachable(
"Unknown/unsupported reduction kind");
// Shared base for lowering shard collective ops (all-reduce, all-gather,
// reduce-scatter). Provides helpers to: build a memref type from a tensor
// type, bufferize a tensor operand, validate the grid (static shape
// required unless allowDynamic), and derive an MPI communicator for a
// subset of grid axes via MPI_Comm_split.
512template <
typename CommOp>
513struct CommOpPattern :
public OpConversionPattern<CommOp> {
514 using OpConversionPattern<CommOp>::OpConversionPattern;
// Memref with the same shape/element type as the given shaped type.
516 MemRefType getMemrefType(ShapedType tensorType)
const {
517 return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
// Returns `input` as a memref, inserting a to_buffer op for tensors.
520 Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder,
521 bool readOnly)
const {
524 if (isa<RankedTensorType>(itype)) {
525 auto memrefType = getMemrefType(cast<ShapedType>(itype));
526 input = bufferization::ToBufferOp::create(iBuilder, memrefType, input,
529 assert(isa<MemRefType>(itype) &&
530 "expected input to be of MemRefType or TensorType");
// Resolves and validates the grid symbol for `op`.
535 FailureOr<GridOp> checkGrid(CommOp op,
536 SymbolTableCollection &symbolTableCollection,
537 bool allowDynamic =
false)
const {
540 return op->emitError() <<
"Missing grid symbol.";
541 if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
542 return op->emitError() <<
"Dynamic grid shape not supported.";
// Builds the communicator spanning `gridAxes`: the world communicator
// when all axes participate, otherwise an MPI_Comm_split keyed by the
// process's index along the non-participating axes.
550 Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
551 ImplicitLocOpBuilder &iBuilder)
const {
552 size_t gridDims = gridOp.getShape().size();
553 auto commType = mpi::CommType::get(gridOp->getContext());
554 Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
556 if (gridAxes.empty() || gridAxes.size() >= gridDims) {
560 SmallVector<GridAxis> otherAxes;
561 for (
GridAxis i = 0; i < static_cast<GridAxis>(gridDims); ++i) {
562 if (!llvm::is_contained(gridAxes, i))
563 otherAxes.emplace_back(i);
566 SmallVector<Type> indexResultTypes(otherAxes.size(),
// MPI_Comm_split takes i32 color/key; cast from index.
571 color = arith::IndexCastOp::create(iBuilder, iBuilder.
getI32Type(), color);
575 key = arith::IndexCastOp::create(iBuilder, iBuilder.
getI32Type(), key);
578 return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
// Conversion pattern for shard::AllReduceOp: copies the input into a fresh
// buffer and performs an in-place mpi.allreduce on it over the communicator
// derived from the op's grid axes. Requires statically shaped, contiguous
// row-major memrefs for both input and result.
583struct ConvertAllReduceOp :
public CommOpPattern<AllReduceOp> {
584 using CommOpPattern::CommOpPattern;
587 matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
588 ConversionPatternRewriter &rewriter)
const override {
589 SymbolTableCollection symbolTableCollection;
590 FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
593 ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
594 Value input = getAsMemref(adaptor.getInput(), iBuilder,
true);
595 MemRefType inType = cast<MemRefType>(input.
getType());
598 "Expected static shaped memref in contiguous row-major layout.");
599 MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
602 "Expected static shaped memref in contiguous row-major layout.");
// MPI_Allreduce is done in place on a copy so the (read-only) input
// buffer is never mutated.
605 Value buffer = memref::AllocOp::create(iBuilder, outType);
606 linalg::CopyOp::create(iBuilder, input, buffer);
608 Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
610 mpi::AllReduceOp::create(iBuilder,
TypeRange(), buffer, buffer,
611 getMPIReductionOp(adaptor.getReductionAttr()),
// Convert back to a tensor if the original op produced one.
615 if (isa<RankedTensorType>(op.getType()))
616 buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
618 rewriter.replaceOp(op, buffer);
// Conversion pattern for shard::ReduceScatterOp, lowered to
// mpi.reduce_scatter_block. MPI scatters along the leading dimension, so
// for scatterDim != 0 the input is first expanded (splitting the scatter
// dim into nRanks x chunk) and transposed to bring the rank dimension to
// the front. NOTE(review): extraction with missing lines (early-return
// paths, `mpiInput` declaration, dealloc ordering) — consult the full file.
623struct ConvertReduceScatterOp :
public CommOpPattern<ReduceScatterOp> {
624 using CommOpPattern::CommOpPattern;
633 matchAndRewrite(ReduceScatterOp op, OpAdaptor adaptor,
634 ConversionPatternRewriter &rewriter)
const override {
635 auto gridAxes = adaptor.getGridAxes();
636 int64_t scatterDim = adaptor.getScatterDimAttr().getInt();
638 SymbolTableCollection symbolTableCollection;
639 FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
643 ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
644 Value rawInput = adaptor.getInput();
645 auto inShapedType = cast<ShapedType>(rawInput.
getType());
646 MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
647 auto elemType = outType.getElementType();
648 auto inputShape = inShapedType.getShape();
649 auto outputShape = outType.getShape();
650 int64_t inputDimOnAxis = inputShape[scatterDim];
651 int64_t outputDimOnAxis = outputShape[scatterDim];
// Validate shapes: equal on non-scatter axes; the scatter axis of the
// input must be a non-zero exact multiple of the output's.
653 for (
size_t i = 0; i < outputShape.size(); ++i)
654 if (outputShape[i] != inputShape[i] &&
655 i !=
static_cast<size_t>(scatterDim))
657 "Result and input shapes must match along non-scatter axes.");
658 if (outputDimOnAxis == 0)
660 "Output size along the scatter axis must be non-zero.");
661 if (inputDimOnAxis % outputDimOnAxis != 0)
663 "Input size along the scatter axis must be an exact "
664 "multiple of the output size along the scatter axis.");
667 return op.emitError(
"Result must be a statically shaped memref in "
668 "contiguous row-major layout.");
// The scatter factor must equal the process-group size on gridAxes.
670 int64_t nRanks = inputDimOnAxis / outputDimOnAxis;
673 int64_t gridGroupSize =
675 if (nRanks != gridGroupSize)
676 return op.emitError()
677 <<
"Expected the scatter factor (" << nRanks
678 <<
") to match the number of devices along grid_axes ("
679 << gridGroupSize <<
").";
682 Value comm = getComm(*gridOp, gridAxes, ib);
// Fast path: scatterDim == 0 already matches MPI's leading-dim scatter.
685 if (scatterDim == 0) {
688 Value input = getAsMemref(rawInput, ib,
true);
689 MemRefType inType = cast<MemRefType>(input.
getType());
691 return op.emitError(
"Input must be a statically shaped memref in "
692 "contiguous row-major layout.");
// Slow path: work on the tensor level to expand + transpose.
700 Value tensorInput = rawInput;
701 if (!isa<RankedTensorType>(rawInput.
getType())) {
702 auto inTensorType = RankedTensorType::get(inputShape, elemType);
704 bufferization::ToTensorOp::create(ib, inTensorType, rawInput,
true);
// Split scatterDim into (nRanks, outputDimOnAxis).
709 SmallVector<int64_t> expandedShape;
710 SmallVector<ReassociationIndices> expandReassociation;
711 int64_t expandedIdx = 0;
712 for (int64_t i = 0; i < static_cast<int64_t>(inputShape.size()); ++i) {
713 if (i == scatterDim) {
714 expandedShape.push_back(nRanks);
715 expandedShape.push_back(outputDimOnAxis);
716 expandReassociation.push_back({expandedIdx, expandedIdx + 1});
719 expandedShape.push_back(inputShape[i]);
720 expandReassociation.push_back({expandedIdx});
724 auto expandedType = RankedTensorType::get(expandedShape, elemType);
725 tensorInput = tensor::ExpandShapeOp::create(ib, expandedType, tensorInput,
726 expandReassociation);
// Rotate the nRanks dimension to the front.
730 SmallVector<int64_t> permutation, transposedShape;
731 permutation.emplace_back(scatterDim);
732 for (int64_t i = 0; i < scatterDim; ++i)
733 permutation.emplace_back(i);
734 for (int64_t i = scatterDim + 1; i < (int64_t)expandedShape.size(); ++i)
735 permutation.emplace_back(i);
736 for (
auto p : permutation)
737 transposedShape.emplace_back(expandedShape[p]);
739 Value permOutput = tensor::EmptyOp::create(ib, transposedShape, elemType);
741 linalg::TransposeOp::create(ib, tensorInput, permOutput, permutation)
// Copy into a plain contiguous buffer for the MPI call.
746 auto mpiInType = MemRefType::get(transposedShape, elemType);
747 Value transposedBuf =
748 bufferization::ToBufferOp::create(ib, mpiInType, tensorInput,
true);
749 mpiInput = memref::AllocOp::create(ib, mpiInType);
750 linalg::CopyOp::create(ib, transposedBuf, mpiInput);
754 Value output = memref::AllocOp::create(ib, outType);
756 mpi::ReduceScatterBlockOp::create(
758 getMPIReductionOp(adaptor.getReductionAttr()), comm);
// Tensor results: wrap the output buffer; otherwise clean up the
// intermediate MPI input buffer allocated on the slow path.
761 if (isa<RankedTensorType>(op.getType()))
763 bufferization::ToTensorOp::create(ib, op.getType(), output,
true);
764 else if (scatterDim != 0)
765 memref::DeallocOp::create(ib, mpiInput);
769 rewriter.replaceOp(op, output);
// Conversion pattern for shard::AllGatherOp, lowered to mpi.allgather.
// MPI gathers along a new leading dimension, so the result is gathered
// into an (nRanks, inputShape...) buffer; for gatherAxis == 0 a collapse
// suffices, otherwise the gathered tensor is transposed so the rank
// dimension lands next to gatherAxis before collapsing. NOTE(review):
// extraction with missing lines (early returns, some declarations such as
// `nRanksV`/`nRanksC` initialization) — consult the full file.
774struct ConvertAllGatherOp :
public CommOpPattern<AllGatherOp> {
775 using CommOpPattern::CommOpPattern;
784 matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
785 ConversionPatternRewriter &rewriter)
const override {
786 SymbolTableCollection symbolTableCollection;
787 FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
791 ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
792 Value input = getAsMemref(adaptor.getInput(), ib,
true);
793 MemRefType inType = cast<MemRefType>(input.
getType());
794 MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
795 auto inputShape = inType.getShape();
796 auto outputShape = outType.getShape();
797 int64_t gatherAxis = adaptor.getGatherAxisAttr().getInt();
798 int64_t inputDimOnAxis = inputShape[gatherAxis];
799 int64_t outputDimOnAxis = outputShape[gatherAxis];
// Validate: shapes equal off-axis; output axis a non-zero exact
// multiple of the input axis.
801 for (
size_t i = 0; i < outputShape.size(); ++i)
802 if (outputShape[i] != inputShape[i] && i != (
size_t)gatherAxis)
804 "Result and input shapes must match along non-gather axes.");
805 if (inputDimOnAxis == 0)
806 return op.emitError(
"Input size along the gather axis must be non-zero.");
// Trivial case: nothing to gather.
807 if (inputDimOnAxis == 1) {
808 assert(outputDimOnAxis == inputDimOnAxis);
809 rewriter.replaceOp(op, adaptor.getInput());
812 if (outputDimOnAxis % inputDimOnAxis != 0)
813 return op.emitError(
"Result size along the gather axis must be an exact "
814 "multiple of the input size along the gather axis.");
818 return op.emitError(
"Input/result must be statically shaped memrefs in "
819 "contiguous row-major layout.");
// Runtime check: communicator size must match the static gather factor.
822 Value comm = getComm(*gridOp, adaptor.getGridAxes(), ib);
824 mpi::CommSizeOp::create(ib, ib.getI32Type(), comm).getSize();
825 nRanksV = arith::IndexCastOp::create(ib, ib.getIndexType(), nRanksV);
826 int64_t nRanks = outputDimOnAxis / inputDimOnAxis;
829 arith::CmpIOp::create(ib, arith::CmpIPredicate::eq, nRanksV, nRanksC);
830 cf::AssertOp::create(ib, notError,
831 "Expected number of ranks in the communicator to "
832 "match the output size along the gather axis divided "
833 "by the input size along the gather axis.");
// Gather into a buffer with a new leading rank dimension.
837 SmallVector<int64_t> gatherShape;
838 gatherShape.emplace_back(nRanks);
839 gatherShape.append(inputShape.begin(), inputShape.end());
840 auto gatherType = MemRefType::get(gatherShape, outType.getElementType());
841 Value finalOutput = memref::AllocOp::create(ib, gatherType);
843 mpi::AllGatherOp::create(ib,
TypeRange(), input, finalOutput, comm);
// gatherAxis == 0: merely collapse {rank, axis0} into one dimension.
845 if (gatherAxis == 0) {
848 SmallVector<ReassociationIndices> reassociation;
849 reassociation.push_back({0, 1});
850 int64_t numGatherDims = gatherShape.size();
851 for (int64_t i = 2; i < numGatherDims; ++i)
852 reassociation.push_back({i});
853 finalOutput = memref::CollapseShapeOp::create(ib, outType, finalOutput,
857 if (isa<RankedTensorType>(op.getType()))
858 finalOutput = bufferization::ToTensorOp::create(ib, op.getType(),
// General case: transpose the rank dimension into position gatherAxis.
863 RankedTensorType::get(gatherShape, outType.getElementType());
865 bufferization::ToTensorOp::create(ib, inType, finalOutput,
true);
869 SmallVector<int64_t> outShapePermuted, permutation;
870 for (
int i = 1; i <= gatherAxis; ++i) {
871 outShapePermuted.emplace_back(gatherShape[i]);
872 permutation.emplace_back(i);
874 outShapePermuted.emplace_back(gatherShape[0]);
875 permutation.emplace_back(0);
876 for (
size_t i = gatherAxis + 1; i < gatherShape.size(); ++i) {
877 outShapePermuted.emplace_back(gatherShape[i]);
878 permutation.emplace_back(i);
880 Value permOutput = tensor::EmptyOp::create(ib, outShapePermuted,
881 outType.getElementType());
883 linalg::TransposeOp::create(ib, finalOutput, permOutput, permutation)
// Collapse {gatherAxis, gatherAxis+1} to restore the output rank.
888 SmallVector<ReassociationIndices> reassociation;
889 for (int64_t i = 0; i < gatherAxis; ++i) {
890 reassociation.push_back({i});
892 reassociation.push_back({gatherAxis, gatherAxis + 1});
893 for (int64_t i = gatherAxis + 2; i < (int64_t)outShapePermuted.size();
895 reassociation.push_back({i});
898 RankedTensorType::get(outputShape, outType.getElementType());
899 finalOutput = tensor::CollapseShapeOp::create(ib, outTType, finalOutput,
// Memref results need bufferizing back from the tensor pipeline above.
903 if (isa<MemRefType>(op.getType()))
905 bufferization::ToBufferOp::create(ib, outType, finalOutput,
false);
908 rewriter.replaceOp(op, finalOutput);
// Conversion pattern for shard::UpdateHaloOp: exchanges halo regions with
// lower/upper neighbors along each split dimension using mpi.send/mpi.recv
// through staging buffers. Iterates split dimensions from the back so that
// corner regions are propagated correctly across successive exchanges.
// NOTE(review): extraction with missing lines (e.g. `array`/`haloSizes`
// initialization, some scf.if closings) — consult the full file.
913struct ConvertUpdateHaloOp :
public OpConversionPattern<UpdateHaloOp> {
914 using OpConversionPattern::OpConversionPattern;
917 matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
918 ConversionPatternRewriter &rewriter)
const override {
933 adaptor.getHaloSizes(), rewriter);
// No halos: the op is a no-op; forward the destination.
934 if (haloSizes.empty()) {
936 rewriter.replaceOp(op, adaptor.getDestination());
940 SymbolTableCollection symbolTableCollection;
941 Location loc = op.getLoc();
// Helper: materialize an OpFoldResult as an index Value.
944 auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
945 if (
auto value = dyn_cast<Value>(v))
947 return arith::ConstantOp::create(
949 rewriter.getIndexAttr(
950 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
953 auto dest = adaptor.getDestination();
954 auto dstShape = cast<ShapedType>(dest.getType()).getShape();
// Operate on a memref; bufferize tensor destinations.
956 if (isa<RankedTensorType>(array.
getType())) {
958 auto mmemrefType = MemRefType::get(
959 dstShape, cast<ShapedType>(array.
getType()).getElementType());
961 bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array);
963 auto rank = cast<ShapedType>(array.
getType()).getRank();
964 auto opSplitAxes = adaptor.getSplitAxes().getAxes();
965 auto grid = adaptor.getGrid();
966 auto gridOp =
getGrid(op, symbolTableCollection);
// Normalize dynamic halo sizes to index type.
968 for (
auto &sz : haloSizes) {
969 if (
auto value = dyn_cast<Value>(sz))
970 sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
// Precompute per-dimension offsets/sizes for the core (non-halo) view.
976 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
977 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
978 SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
979 auto currHaloDim = -1;
981 for (
auto i = 0; i < rank; ++i) {
982 auto s = dstShape[i];
983 if (ShapedType::isDynamic(s))
984 shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult();
986 shape[i] = rewriter.getIndexAttr(s);
988 if ((
size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
// Split dimension: core region excludes lower+upper halos.
991 offsets[i] = haloSizes[currHaloDim * 2];
994 Value _haloSz = arith::AddIOp::create(
995 rewriter, loc, toValue(haloSizes[currHaloDim * 2]),
996 toValue(haloSizes[currHaloDim * 2 + 1]));
999 arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz)
1002 dimSizes[i] = shape[i];
// Fixed message tag for all halo exchanges.
1006 auto tagAttr = rewriter.getI32IntegerAttr(91);
1007 auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr);
1008 auto zeroAttr = rewriter.getI32IntegerAttr(0);
1009 auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
1011 SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
1012 rewriter.getIndexType());
1014 ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
// Iterate split dims back-to-front; a negative neighbor index means
// "at the grid border, no neighbor".
1017 for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
1018 auto splitAxes = opSplitAxes[dim];
1019 if (splitAxes.empty())
1021 assert(currHaloDim >= 0 && (
size_t)currHaloDim < haloSizes.size() / 2);
1024 auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
1025 myMultiIndex, splitAxes)
1028 Value neighbourIDs[2] = {
1029 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
1031 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(),
// Offsets of the four regions involved in the exchange along `dim`.
1034 auto lowerRecvOffset = rewriter.getIndexAttr(0);
1035 auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
1036 auto upperRecvOffset =
1037 arith::SubIOp::create(rewriter, loc, toValue(shape[dim]),
1038 toValue(haloSizes[currHaloDim * 2 + 1]));
1039 auto upperSendOffset = arith::SubIOp::create(
1040 rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
1042 Value commWorld = mpi::CommWorldOp::create(
1043 rewriter, loc, mpi::CommType::get(op->getContext()));
// Emits one guarded send + one guarded recv for the upper or lower
// halo, staging data through a temporary contiguous buffer.
1051 auto genSendRecv = [&](
bool upperHalo) {
1052 auto orgOffset = offsets[dim];
1053 dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
1054 : haloSizes[currHaloDim * 2];
1057 auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
1058 auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
// Neighbor index < 0 means no neighbor on that side.
1059 auto hasFrom = arith::CmpIOp::create(
1060 rewriter, loc, arith::CmpIPredicate::sge, from, zero);
1061 auto hasTo = arith::CmpIOp::create(rewriter, loc,
1062 arith::CmpIPredicate::sge, to, zero);
1063 auto buffer = memref::AllocOp::create(
1064 rewriter, loc, dimSizes,
1065 cast<ShapedType>(array.
getType()).getElementType());
1068 rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) {
1069 offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
1070 : OpFoldResult(upperSendOffset);
1071 auto subview = memref::SubViewOp::create(
1072 builder, loc, array, offsets, dimSizes, strides);
1073 memref::CopyOp::create(builder, loc, subview, buffer);
1074 mpi::SendOp::create(builder, loc,
TypeRange{}, buffer, tag, to,
1076 scf::YieldOp::create(builder, loc);
1080 rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) {
1081 offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
1082 : OpFoldResult(lowerRecvOffset);
1083 mpi::RecvOp::create(builder, loc,
TypeRange{}, buffer, tag, from,
1085 auto subview = memref::SubViewOp::create(
1086 builder, loc, array, offsets, dimSizes, strides);
1087 memref::CopyOp::create(builder, loc, buffer, subview);
1088 scf::YieldOp::create(builder, loc);
1090 memref::DeallocOp::create(rewriter, loc, buffer);
1091 offsets[dim] = orgOffset;
// Wraps genSendRecv in a guard that skips zero-sized halos.
1094 auto doSendRecv = [&](
int upOrDown) {
1095 OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
1096 Value haloSz = dyn_cast<Value>(v);
1098 haloSz = arith::ConstantOp::create(
1100 rewriter.getI32IntegerAttr(
1101 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
1102 auto hasSize = arith::CmpIOp::create(
1103 rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero);
1104 scf::IfOp::create(rewriter, loc, hasSize,
1105 [&](OpBuilder &builder, Location loc) {
1106 genSendRecv(upOrDown > 0);
1107 scf::YieldOp::create(builder, loc);
// Restore the full extent for this dim before the next iteration.
1115 dimSizes[dim] = shape[dim];
1117 offsets[dim] = rewriter.getIndexAttr(0);
// Result: memref results take the buffer, tensor results are wrapped.
1122 if (isa<MemRefType>(op.getResult().getType())) {
1123 rewriter.replaceOp(op, array);
1125 assert(isa<RankedTensorType>(op.getResult().getType()));
1126 rewriter.replaceOp(op, bufferization::ToTensorOp::create(
1127 rewriter, loc, op.getResult().getType(), array,
// The shard-to-MPI conversion pass driver: sets up the type converter
// (ShardingType -> three dynamic tensors), materializations, conversion
// target legality, and the pattern set, then runs a partial conversion.
1134struct ConvertShardToMPIPass
1135 :
public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
1139 void runOnOperation()
override {
1141 RewritePatternSet patterns(ctxt);
// Identity conversion for all types not handled explicitly below.
1146 TypeConverter typeConverter;
1147 typeConverter.addConversion([](Type type) {
return type; });
// A ShardingType converts 1:3 — split axes (i16 tensor), halo sizes
// (i64 tensor), sharded dims offsets (i64 tensor), all dynamic-shaped.
1150 typeConverter.addConversion(
1151 [](ShardingType type,
1152 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
1153 auto i16 = IntegerType::get(type.getContext(), 16);
1154 auto i64 = IntegerType::get(type.getContext(), 64);
1155 std::array<int64_t, 2> shp = {ShapedType::kDynamic,
1156 ShapedType::kDynamic};
1157 results.emplace_back(RankedTensorType::get(shp, i16));
1158 results.emplace_back(RankedTensorType::get(shp, i64));
1159 results.emplace_back(RankedTensorType::get(shp, i64));
// Unpack the tuple produced by ConvertShardingOp back into its three
// tensor components when materializing target values.
1165 typeConverter.addTargetMaterialization(
1169 if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
1170 return SmallVector<Value>();
1171 auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
1174 return SmallVector<Value>();
1176 SmallVector<Value> results;
1177 for (
auto oprnd : castOp.getInputs()) {
1178 if (!isa<RankedTensorType>(oprnd.getType()))
1179 return SmallVector<Value>();
1180 results.emplace_back(oprnd);
// Everything from the shard dialect must be converted, except the
// grid declaration/shape ops which stay legal.
1186 target.addIllegalDialect<shard::ShardDialect>();
1188 target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
1190 target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
1191 arith::ArithDialect, tensor::TensorDialect,
1192 bufferization::BufferizationDialect,
1193 linalg::LinalgDialect, memref::MemRefDialect,
1194 affine::AffineDialect, cf::ControlFlowDialect>();
// Function boundaries are legal only once their signatures/operands
// use converted types.
1196 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1197 return typeConverter.isSignatureLegal(op.getFunctionType());
1199 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
1200 [&](Operation *op) {
return typeConverter.isLegal(op); });
1202 patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
1203 ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
1204 ConvertAllGatherOp, ConvertAllReduceOp, ConvertReduceScatterOp,
1205 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
1206 SymbolTableCollection stc;
1210 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
1211 patterns, typeConverter);
// Partial conversion: leftover legal ops are tolerated.
1215 (void)applyPartialConversion(getOperation(),
target, std::move(patterns));
1219 SymbolTableCollection symbolTableCollection;
static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
This class helps build Operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateAllSliceOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateProcessMultiIndexOpLoweringPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
TypedValue< IndexType > createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid, ArrayRef< GridAxis > gridAxes={})
SmallVector< Value > getMixedAsValues(OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but with all elements equal to ShapedType::kDynamic replaced by the corresponding entry of dynamicValues.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...