36 #define DEBUG_TYPE "mesh-to-mpi"
37 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
40 #define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
41 #include "mlir/Conversion/Passes.h.inc"
55 auto dyn = dynamics.begin();
60 "expected an i64 or an intex type");
61 for (
auto s : statics) {
62 if (s == ShapedType::kDynamic) {
63 values.emplace_back(*(dyn++));
66 values.emplace_back(b.
create<arith::ConstantOp>(loc, type, val));
76 int n = dimensions.size();
79 for (
int i = n - 1; i >= 0; --i) {
80 multiIndex[i] = b.
create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
82 linearIndex = b.
create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
92 Value linearIndex = b.
create<arith::ConstantIndexOp>(loc, 0);
93 Value stride = b.
create<arith::ConstantIndexOp>(loc, 1);
95 for (
int i = multiIndex.size() - 1; i >= 0; --i) {
96 Value off = b.
create<arith::MulIOp>(loc, multiIndex[i], stride);
97 linearIndex = b.
create<arith::AddIOp>(loc, linearIndex, off);
98 stride = b.
create<arith::MulIOp>(loc, stride, dimensions[i]);
109 matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
111 auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
114 auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
118 rewriter.
replaceOp(op, shardingOp.getResult());
130 matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
132 auto splitAxes = op.getSplitAxes().getAxes();
133 int64_t maxNAxes = 0;
134 for (
auto axes : splitAxes)
135 maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
143 std::array<int64_t, 2> shape = {
static_cast<int64_t
>(splitAxes.size()),
145 Value resSplitAxes = rewriter.
create<tensor::EmptyOp>(loc, shape, i16);
147 Value fillValue = rewriter.
create<arith::ConstantOp>(loc, i16, attr);
148 resSplitAxes = rewriter.
create<linalg::FillOp>(loc, fillValue, resSplitAxes)
152 std::array<int64_t, 2> strides = {1, 1};
156 int64_t size = axes.size();
159 std::array<int64_t, 2> offs = {(int64_t)i, 0};
160 std::array<int64_t, 2> sizes = {1, size};
163 auto vals = rewriter.
create<arith::ConstantOp>(loc, tensorType, attrs);
164 resSplitAxes = rewriter.
create<tensor::InsertSliceOp>(
165 loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
171 getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
172 adaptor.getDynamicHaloSizes());
177 .
create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
180 : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
189 if (adaptor.getStaticShardedDimsOffsets().empty()) {
190 resOffsets = rewriter.
create<tensor::EmptyOp>(
191 loc, std::array<int64_t, 2>{0, 0}, i64);
194 auto meshOp =
getMesh(op, symbolTableCollection);
195 int64_t maxSplitSize = 0;
196 for (
auto axes : splitAxes) {
199 assert(splitSize != ShapedType::kDynamic);
200 maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
202 assert(maxSplitSize);
205 resOffsets = rewriter.
create<tensor::EmptyOp>(
206 loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
210 rewriter.
create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
212 getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
213 adaptor.getDynamicShardedDimsOffsets());
218 assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
221 Value vals = rewriter.
create<tensor::FromElementsOp>(loc, values);
222 std::array<int64_t, 2> offs = {
static_cast<int64_t
>(i), 0};
223 std::array<int64_t, 2> sizes = {1, splitSize};
224 resOffsets = rewriter.
create<tensor::InsertSliceOp>(
225 loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
232 if (failed(getTypeConverter()->convertType(op.getResult().getType(),
237 rewriter.
create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
239 rewriter.
create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
240 resOffsets = rewriter.
create<tensor::CastOp>(loc, resTypes[2], resOffsets);
244 ValueRange{resSplitAxes, resHaloSizes, resOffsets});
250 struct ConvertProcessMultiIndexOp
255 matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
262 auto meshOp =
getMesh(op, symbolTableCollection);
264 if (ShapedType::isDynamicShape(meshOp.getShape()))
269 meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
270 return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
272 Value rank = rewriter.
create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
273 auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
276 auto axes = adaptor.getAxes();
279 for (
auto axis : axes) {
280 subIndex.emplace_back(mIdx[axis]);
282 mIdx = std::move(subIndex);
290 class ConvertProcessLinearIndexOp
298 ConvertProcessLinearIndexOp(
const TypeConverter &typeConverter,
303 matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
307 if (worldRank >= 0) {
313 auto ctx = op.getContext();
329 struct ConvertNeighborsLinearIndicesOp
334 matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
341 auto axes = adaptor.getSplitAxes();
343 if (axes.size() != 1)
348 auto meshOp =
getMesh(op, symbolTableCollection);
349 auto mIdx = adaptor.getDevice();
350 auto orgIdx = mIdx[axes[0]];
353 meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
354 return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
356 Value dimSz = dims[axes[0]];
357 Value one = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
358 Value minus1 = rewriter.
create<arith::ConstantIndexOp>(loc, -1);
360 loc, arith::CmpIPredicate::sle, orgIdx,
361 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
362 auto down = rewriter.
create<scf::IfOp>(
365 builder.create<scf::YieldOp>(loc, minus1);
370 rewriter.
create<arith::SubIOp>(op.getLoc(), orgIdx, one)
372 builder.create<scf::YieldOp>(
373 loc, multiToLinearIndex(loc, rewriter, tmp, dims));
375 atBorder = rewriter.
create<arith::CmpIOp>(
376 loc, arith::CmpIPredicate::sge, orgIdx,
377 rewriter.
create<arith::SubIOp>(loc, dimSz, one).getResult());
378 auto up = rewriter.
create<scf::IfOp>(
381 builder.create<scf::YieldOp>(loc, minus1);
386 rewriter.
create<arith::AddIOp>(op.getLoc(), orgIdx, one);
387 builder.
create<scf::YieldOp>(
388 loc, multiToLinearIndex(loc, rewriter, tmp, dims));
399 matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
401 auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
403 return op->emitError()
404 <<
"Expected SharingOp as defining op for sharding"
405 <<
" but found " << adaptor.getSharding()[0].getDefiningOp();
425 for (
auto dim : adaptor.getDimsDynamic()) {
427 dynDims.emplace_back(llvm::getSingleElement(dim));
430 for (
auto device : adaptor.getDeviceDynamic()) {
431 dynDevice.emplace_back(llvm::getSingleElement(device));
437 getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
439 getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
443 auto meshOp =
getMesh(sharding, symbolTableCollection);
445 if (ShapedType::isDynamicShape(meshOp.getShape()))
448 auto splitAxes = sharding.getSplitAxes().getAxes();
454 Value shardedDimsOffs;
457 rewriter, loc, sharding.getStaticShardedDimsOffsets(),
458 sharding.getDynamicShardedDimsOffsets(), index);
460 shardedDimsOffs = rewriter.
create<tensor::FromElementsOp>(
478 if (i < splitAxes.size() && !splitAxes[i].empty()) {
479 auto axes = splitAxes[i];
485 Value idx = multiIdx[axes[0]];
488 if (shardedDimsOffs) {
491 if (axes.size() > 1) {
492 return op->emitError() <<
"Only single axis sharding is "
493 <<
"supported for each dimension.";
495 idx = rewriter.
create<arith::AddIOp>(loc, posVal, idx);
498 rewriter.
create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
499 idx = rewriter.
create<arith::AddIOp>(loc, idx, one);
501 rewriter.
create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
502 Value sz = rewriter.
create<arith::SubIOp>(loc, nextOff, off);
505 Value numShardsVal = rewriter.
create<arith::ConstantOp>(
511 Value sz = rewriter.
create<arith::DivSIOp>(loc, dim, numShardsVal);
512 Value sz1 = rewriter.
create<arith::RemSIOp>(loc, dim, numShardsVal);
513 sz1 = rewriter.
create<arith::SubIOp>(loc, numShardsVal, sz1);
514 auto cond = rewriter.
create<arith::CmpIOp>(
515 loc, arith::CmpIPredicate::sge, idx, sz1);
516 Value odd = rewriter.
create<arith::SelectOp>(loc, cond, one, zero);
517 sz = rewriter.
create<arith::AddIOp>(loc, sz, odd);
520 pos += numShards + 1;
536 matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
552 adaptor.getHaloSizes(), rewriter);
553 if (haloSizes.empty()) {
555 rewriter.
replaceOp(op, adaptor.getDestination());
564 if (
auto value = dyn_cast<Value>(v))
566 return rewriter.
create<arith::ConstantOp>(
568 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
571 auto dest = adaptor.getDestination();
572 auto dstShape = cast<ShapedType>(dest.getType()).getShape();
574 if (isa<RankedTensorType>(array.
getType())) {
577 dstShape, cast<ShapedType>(array.
getType()).getElementType());
579 rewriter.
create<bufferization::ToMemrefOp>(loc, tensorType, array);
581 auto rank = cast<ShapedType>(array.
getType()).getRank();
582 auto opSplitAxes = adaptor.getSplitAxes().getAxes();
583 auto mesh = adaptor.getMesh();
584 auto meshOp =
getMesh(op, symbolTableCollection);
586 for (
auto &sz : haloSizes) {
587 if (
auto value = dyn_cast<Value>(sz))
598 auto currHaloDim = -1;
600 for (
auto i = 0; i < rank; ++i) {
601 auto s = dstShape[i];
602 if (ShapedType::isDynamic(s))
603 shape[i] = rewriter.
create<memref::DimOp>(loc, array, s).getResult();
607 if ((
size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
610 offsets[i] = haloSizes[currHaloDim * 2];
614 loc, toValue(haloSizes[currHaloDim * 2]),
615 toValue(haloSizes[currHaloDim * 2 + 1]));
618 rewriter.
create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
621 dimSizes[i] = shape[i];
626 auto tag = rewriter.
create<arith::ConstantOp>(loc, tagAttr);
628 auto zero = rewriter.
create<arith::ConstantOp>(loc, zeroAttr);
633 rewriter.
create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
636 for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
637 auto splitAxes = opSplitAxes[dim];
638 if (splitAxes.empty())
640 assert(currHaloDim >= 0 && (
size_t)currHaloDim < haloSizes.size() / 2);
644 .
create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
648 Value neighbourIDs[2] = {rewriter.
create<arith::IndexCastOp>(
650 rewriter.
create<arith::IndexCastOp>(
654 auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
655 auto upperRecvOffset = rewriter.
create<arith::SubIOp>(
656 loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
657 auto upperSendOffset = rewriter.
create<arith::SubIOp>(
658 loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
660 Value commWorld = rewriter.
create<mpi::CommWorldOp>(
669 auto genSendRecv = [&](
bool upperHalo) {
670 auto orgOffset = offsets[dim];
671 dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
672 : haloSizes[currHaloDim * 2];
675 auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
676 auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
677 auto hasFrom = rewriter.
create<arith::CmpIOp>(
678 loc, arith::CmpIPredicate::sge, from, zero);
679 auto hasTo = rewriter.
create<arith::CmpIOp>(
680 loc, arith::CmpIPredicate::sge, to, zero);
681 auto buffer = rewriter.
create<memref::AllocOp>(
682 loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
684 rewriter.
create<scf::IfOp>(
686 offsets[dim] = upperHalo ?
OpFoldResult(lowerSendOffset)
688 auto subview = builder.create<memref::SubViewOp>(
689 loc, array, offsets, dimSizes, strides);
690 builder.create<memref::CopyOp>(loc, subview, buffer);
691 builder.create<mpi::SendOp>(loc,
TypeRange{}, buffer, tag, to,
693 builder.create<scf::YieldOp>(loc);
696 rewriter.
create<scf::IfOp>(
698 offsets[dim] = upperHalo ?
OpFoldResult(upperRecvOffset)
700 builder.create<mpi::RecvOp>(loc,
TypeRange{}, buffer, tag, from,
702 auto subview = builder.create<memref::SubViewOp>(
703 loc, array, offsets, dimSizes, strides);
704 builder.create<memref::CopyOp>(loc, buffer, subview);
705 builder.create<scf::YieldOp>(loc);
707 rewriter.
create<memref::DeallocOp>(loc, buffer);
708 offsets[dim] = orgOffset;
711 auto doSendRecv = [&](
int upOrDown) {
712 OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
713 Value haloSz = dyn_cast<Value>(v);
715 haloSz = rewriter.
create<arith::ConstantOp>(
717 cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
718 auto hasSize = rewriter.
create<arith::CmpIOp>(
719 loc, arith::CmpIPredicate::sgt, haloSz, zero);
720 rewriter.
create<scf::IfOp>(loc, hasSize,
722 genSendRecv(upOrDown > 0);
723 builder.create<scf::YieldOp>(loc);
731 dimSizes[dim] = shape[dim];
738 if (isa<MemRefType>(op.getResult().getType())) {
741 assert(isa<RankedTensorType>(op.getResult().getType()));
743 loc, op.getResult().getType(), array,
750 struct ConvertMeshToMPIPass
751 :
public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
755 void runOnOperation()
override {
756 uint64_t worldRank = -1;
761 dlti::query(getOperation(), {
"MPI:comm_world_rank"},
false);
762 if (succeeded(dltiAttr)) {
763 if (!isa<IntegerAttr>(dltiAttr.value())) {
764 getOperation()->emitError()
765 <<
"Expected an integer attribute for MPI:comm_world_rank";
766 return signalPassFailure();
768 worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
783 [](ShardingType type,
787 std::array<int64_t, 2> shp = {ShapedType::kDynamic,
788 ShapedType::kDynamic};
801 if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
803 auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
809 for (
auto oprnd : castOp.getInputs()) {
810 if (!isa<RankedTensorType>(oprnd.getType()))
812 results.emplace_back(oprnd);
818 target.addIllegalDialect<mesh::MeshDialect>();
822 target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
823 arith::ArithDialect, tensor::TensorDialect,
824 bufferization::BufferizationDialect,
825 linalg::LinalgDialect, memref::MemRefDialect>();
827 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
830 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
833 patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
834 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
835 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
837 patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
839 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
static MLIRContext * getContext(OpFoldResult val)
static void shardShape(const InShape &inShape, const MeshShape &meshShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef< int64_t > shardedDimsOffsets={}, ArrayRef< int64_t > haloSizes={})
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
TypedAttr getOneAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replaceOpWithNewOp).
This class represents a collection of SymbolTables.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type according to a pattern's type converter.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing attributes.
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the given type converter.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true will be replaced by dynamicValues.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalized by the conversion framework.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.