#define DEBUG_TYPE "mesh-to-mpi"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
// Body of getMixedAsValues: materialize a mixed list of static values and
// dynamic Values as a list of Values, emitting constants for static entries.
  auto dyn = dynamics.begin();
  // ...
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      // ...
      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
    }
  }
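
// A self-contained sketch (not part of this file) of the interleaving that
// getMixedAsValues performs: every kDynamic entry of the static list is
// replaced by the next dynamic value; every other entry is kept (the pass
// emits it as an arith.constant). All names below are illustrative only.
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

static std::vector<int64_t>
mixStaticAndDynamic(const std::vector<int64_t> &statics,
                    const std::vector<int64_t> &dynamics) {
  // Stand-in for ShapedType::kDynamic.
  const int64_t kDynamic = std::numeric_limits<int64_t>::min();
  std::vector<int64_t> values;
  auto dyn = dynamics.begin();
  for (int64_t s : statics)
    values.push_back(s == kDynamic ? *(dyn++) : s);
  assert(dyn == dynamics.end() && "all dynamic values must be consumed");
  return values;
}
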
// linearToMultiIndex: delinearize a linear process index into a
// multi-dimensional index over `dimensions` (row-major).
  int n = dimensions.size();
  // ...
  for (int i = n - 1; i >= 0; --i) {
    multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
    linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
  }
// multiToLinearIndex: linearize a multi-dimensional index over `dimensions`
// (row-major).
  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
  // ...
  for (int i = multiIndex.size() - 1; i >= 0; --i) {
    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
    linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
    stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
  }
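
// Illustrative, self-contained sketch (not part of this pass) of the
// arithmetic the two helpers above emit as arith ops: row-major
// delinearization and linearization of a process index. Names are
// illustrative only.
#include <cstdint>
#include <vector>

static std::vector<int64_t> linearToMulti(int64_t linear,
                                          const std::vector<int64_t> &dims) {
  std::vector<int64_t> multi(dims.size());
  for (int i = (int)dims.size() - 1; i >= 0; --i) {
    multi[i] = linear % dims[i]; // arith.remsi
    linear /= dims[i];           // arith.divsi
  }
  return multi;
}

static int64_t multiToLinear(const std::vector<int64_t> &multi,
                             const std::vector<int64_t> &dims) {
  int64_t linear = 0, stride = 1;
  for (int i = (int)dims.size() - 1; i >= 0; --i) {
    linear += multi[i] * stride; // arith.muli + arith.addi
    stride *= dims[i];           // arith.muli
  }
  return linear;
}
// E.g. on a 2x3 mesh, linear rank 4 maps to (1, 1) and back.
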
// ConvertGetShardingOp
LogicalResult
matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
  // ...
  auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
  // ...
  rewriter.replaceOp(op, shardingOp.getResult());
}
// ConvertShardingOp: lower mesh.sharding to a tuple of three tensors
// (split axes, halo sizes, sharded-dims offsets).
LogicalResult
matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto splitAxes = op.getSplitAxes().getAxes();
  int64_t maxNAxes = 0;
  for (auto axes : splitAxes)
    maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
  // ...
  std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                  maxNAxes};
  // ...
  Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
  // ...
  Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
  resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                     .getResult(0);
  // ...
  std::array<int64_t, 2> strides = {1, 1};
  // ... for each group `axes` at position `i` in splitAxes:
    int64_t size = axes.size();
    // ...
    std::array<int64_t, 2> offs = {(int64_t)i, 0};
    std::array<int64_t, 2> sizes = {1, size};
    // ...
    auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
    resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
        loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
  // ... halo sizes as Values:
      getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                       adaptor.getDynamicHaloSizes());
  // ... empty halo sizes -> 0x0 tensor, otherwise tensor.from_elements:
          .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
  // ...
      : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
  // ...
  if (adaptor.getStaticShardedDimsOffsets().empty()) {
    resOffsets = rewriter.create<tensor::EmptyOp>(
        loc, std::array<int64_t, 2>{0, 0}, i64);
  } else {
    auto meshOp = getMesh(op, symbolTableCollection);
    int64_t maxSplitSize = 0;
    for (auto axes : splitAxes) {
      // ...
      assert(splitSize != ShapedType::kDynamic);
      maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
    }
    assert(maxSplitSize);
    // ...
    resOffsets = rewriter.create<tensor::EmptyOp>(
        loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
    // ...
    resOffsets =
        rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
    // ... sharded-dims offsets as Values:
        getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                         adaptor.getDynamicShardedDimsOffsets());
    // ... for each sharded dimension `i` with `splitSize` offsets:
      assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
      // ...
      Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
      std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
      std::array<int64_t, 2> sizes = {1, splitSize};
      resOffsets = rewriter.create<tensor::InsertSliceOp>(
          loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
  }
  // ...
  if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                             resTypes)))
    return failure();
  resSplitAxes =
      rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
  resHaloSizes =
      rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
  resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
  // ... replace the op with the three converted tensors:
      ValueRange{resSplitAxes, resHaloSizes, resOffsets});
}
// ConvertProcessMultiIndexOp: compute the process's multi-index on the mesh
// from its linear index.
struct ConvertProcessMultiIndexOp
    : public OpConversionPattern<ProcessMultiIndexOp> {
  // ...
  LogicalResult
  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    auto meshOp = getMesh(op, symbolTableCollection);
    // ...
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      // ...
    // ... mesh shape as constant index Values:
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);

    // If only a subset of the mesh axes was requested, extract those entries.
    auto axes = adaptor.getAxes();
    // ...
    for (auto axis : axes) {
      subIndex.emplace_back(mIdx[axis]);
    }
    mIdx = std::move(subIndex);
    // ...
  }
};
// ConvertProcessLinearIndexOp: lower mesh.process_linear_index.
class ConvertProcessLinearIndexOp
    : public OpConversionPattern<ProcessLinearIndexOp> {
  // ...
  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter, /*...*/)
  // ...
  LogicalResult
  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (worldRank >= 0) { // world rank known at pass-construction time
// ConvertNeighborsLinearIndicesOp: compute the linear indices of the "down"
// and "up" neighbors along a single split axis (-1 at the mesh borders).
struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  // ...
  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    auto axes = adaptor.getSplitAxes();
    // Only a single split axis is supported.
    if (axes.size() != 1)
      // ...
    auto meshOp = getMesh(op, symbolTableCollection);
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    // ... mesh shape as constant index Values:
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
    // "Down" neighbor: -1 if this process sits at the lower border.
    auto atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, orgIdx,
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    auto down = rewriter.create<scf::IfOp>(
        // ...
          builder.create<scf::YieldOp>(loc, minus1);
        // ... otherwise linearize the index decremented along the axis:
              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
          // ...
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        // ...
    // "Up" neighbor: -1 if this process sits at the upper border.
    atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, orgIdx,
        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
    auto up = rewriter.create<scf::IfOp>(
        // ...
          builder.create<scf::YieldOp>(loc, minus1);
        // ... otherwise linearize the index incremented along the axis:
              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        // ...
  }
};
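
// Illustrative, self-contained sketch (not part of this pass) of the values
// the scf.if chains above compute for one split axis: the linear indices of
// the "down" and "up" neighbors, or -1 when the process sits at a mesh
// border. Names are illustrative only.
#include <cstdint>
#include <vector>

struct Neighbors {
  int64_t down; // neighbor at index - 1 along the split axis, or -1
  int64_t up;   // neighbor at index + 1 along the split axis, or -1
};

static Neighbors neighborsAlongAxis(std::vector<int64_t> multiIdx,
                                    const std::vector<int64_t> &dims,
                                    int axis) {
  auto linearize = [&](const std::vector<int64_t> &idx) {
    int64_t linear = 0, stride = 1;
    for (int i = (int)dims.size() - 1; i >= 0; --i) {
      linear += idx[i] * stride;
      stride *= dims[i];
    }
    return linear;
  };
  Neighbors n{-1, -1};
  int64_t org = multiIdx[axis];
  if (org > 0) {              // not at the lower border
    multiIdx[axis] = org - 1;
    n.down = linearize(multiIdx);
  }
  if (org < dims[axis] - 1) { // not at the upper border
    multiIdx[axis] = org + 1;
    n.up = linearize(multiIdx);
  }
  return n;
}
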
// ConvertShardShapeOp: compute the local shard shape for a given device.
LogicalResult
matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
  // ...
    return op->emitError()
           << "Expected ShardingOp as defining op for sharding"
           << " but found " << adaptor.getSharding()[0].getDefiningOp();
  // ...
  for (auto dim : adaptor.getDimsDynamic()) {
    // ...
    assert(dim.size() == 1);
    dynDims.emplace_back(dim[0]);
  }
  // ...
  for (auto device : adaptor.getDeviceDynamic()) {
    assert(device.size() == 1);
    dynDevice.emplace_back(device[0]);
  }
  // ... dims and device as index Values:
      getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
  // ...
      getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
  // ...
  auto meshOp = getMesh(sharding, symbolTableCollection);
  // ...
  if (ShapedType::isDynamicShape(meshOp.getShape()))
    // ...
  auto splitAxes = sharding.getSplitAxes().getAxes();
  // ...
  Value shardedDimsOffs;
  // ... sharded-dims offsets as index Values:
      rewriter, loc, sharding.getStaticShardedDimsOffsets(),
      sharding.getDynamicShardedDimsOffsets(), index);
  // ...
    shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
  // ... for each result dimension i:
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      auto axes = splitAxes[i];
      // ...
      Value idx = multiIdx[axes[0]];
      // ...
      if (shardedDimsOffs) {
        // Shard sizes are differences of consecutive precomputed offsets.
        if (axes.size() > 1) {
          return op->emitError() << "Only single axis sharding is "
                                 << "supported for each dimension.";
        }
        // ...
        idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
        // ...
        Value off =
            rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
        idx = rewriter.create<arith::AddIOp>(loc, idx, one);
        // ...
        Value nextOff =
            rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
        Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
        // ...
      } else {
        // Split the dimension evenly; the remainder goes to trailing shards.
        Value numShardsVal = rewriter.create<arith::ConstantOp>(
        // ...
        Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
        Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
        sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
        auto cond = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, idx, sz1);
        Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
        sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
        // ...
      }
      pos += numShards + 1;
    }
    // ...
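
// Illustrative, self-contained sketch (not part of this pass) of the even
// split computed in the else branch above when no explicit sharded-dims
// offsets are given: each shard gets dim / numShards elements, and the
// remainder is assigned, one element each, to the trailing shards. Names are
// illustrative only.
#include <cstdint>

static int64_t localShardSize(int64_t dim, int64_t numShards,
                              int64_t shardIdx) {
  int64_t sz = dim / numShards;                          // arith.divsi
  int64_t firstBigShard = numShards - dim % numShards;   // arith.remsi + subi
  return shardIdx >= firstBigShard ? sz + 1 : sz;        // cmpi + select + addi
}
// E.g. dim = 10 over 4 shards gives local sizes 2, 2, 3, 3.
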
// ConvertUpdateHaloOp: lower mesh.update_halo to per-dimension MPI send/recv
// pairs on memref views of the destination buffer.
LogicalResult
matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ... halo sizes as mixed static/dynamic values:
      adaptor.getHaloSizes(), rewriter);
  if (haloSizes.empty()) {
    // No halos: nothing to exchange.
    rewriter.replaceOp(op, adaptor.getDestination());
    return success();
  }
  // ...
  // toValue: turn an OpFoldResult into a Value, emitting a constant for
  // attributes.
    if (auto value = dyn_cast<Value>(v))
      return value;
    return rewriter.create<arith::ConstantOp>(
        // ...
            cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
  // ...
  auto dest = adaptor.getDestination();
  auto dstShape = cast<ShapedType>(dest.getType()).getShape();
  // If the destination is a tensor, operate on a memref view of it.
  if (isa<RankedTensorType>(array.getType())) {
    // ...
        dstShape, cast<ShapedType>(array.getType()).getElementType());
    array =
        rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array);
  }
  auto rank = cast<ShapedType>(array.getType()).getRank();
  auto opSplitAxes = adaptor.getSplitAxes().getAxes();
  auto mesh = adaptor.getMesh();
  auto meshOp = getMesh(op, symbolTableCollection);
  // ...
  for (auto &sz : haloSizes) {
    if (auto value = dyn_cast<Value>(sz))
      // ...
  }
  // ...
  auto currHaloDim = -1;
  // ...
  for (auto i = 0; i < rank; ++i) {
    auto s = dstShape[i];
    if (ShapedType::isDynamic(s))
      shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
    // ...
    if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
      // ...
      offsets[i] = haloSizes[currHaloDim * 2];
      // ... total halo = lower + upper halo size:
          loc, toValue(haloSizes[currHaloDim * 2]),
          toValue(haloSizes[currHaloDim * 2 + 1]));
      // ... core size = shape[i] - total halo:
          rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
      // ...
    } else {
      dimSizes[i] = shape[i];
    }
  }
  // ...
  auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
  // ...
  auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
  // ... this process's multi-index on the mesh (myMultiIndex):
      rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
  // ...
  // Traverse split dimensions from last to first.
  for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
    auto splitAxes = opSplitAxes[dim];
    if (splitAxes.empty())
      continue;
    assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
    // ... linear indices of the neighbors along this dimension:
            .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
    // ...
    Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
                                 // ...
                             rewriter.create<arith::IndexCastOp>(
                                 // ...
    auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
    auto upperRecvOffset = rewriter.create<arith::SubIOp>(
        loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
    auto upperSendOffset = rewriter.create<arith::SubIOp>(
        loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
    // ...
    // genSendRecv: pack the owned boundary region into a temporary buffer,
    // send it to one neighbor and receive the opposite halo from the other.
    auto genSendRecv = [&](bool upperHalo) {
      auto orgOffset = offsets[dim];
      dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                : haloSizes[currHaloDim * 2];
      // ...
      auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
      auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
      auto hasFrom = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::sge, from, zero);
      auto hasTo = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::sge, to, zero);
      auto buffer = rewriter.create<memref::AllocOp>(
          loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
      // Send only if the target neighbor exists.
      rewriter.create<scf::IfOp>(
          // ...
            offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
                                     : OpFoldResult(upperSendOffset);
            auto subview = builder.create<memref::SubViewOp>(
                loc, array, offsets, dimSizes, strides);
            builder.create<memref::CopyOp>(loc, subview, buffer);
            builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
            builder.create<scf::YieldOp>(loc);
          // ...
      // Receive only if the source neighbor exists.
      rewriter.create<scf::IfOp>(
          // ...
            offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                     : /*...*/;
            builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
            auto subview = builder.create<memref::SubViewOp>(
                loc, array, offsets, dimSizes, strides);
            builder.create<memref::CopyOp>(loc, buffer, subview);
            builder.create<scf::YieldOp>(loc);
          // ...
      rewriter.create<memref::DeallocOp>(loc, buffer);
      offsets[dim] = orgOffset;
    };
    // doSendRecv: exchange one halo (0 = lower, 1 = upper) only if its size
    // is non-zero.
    auto doSendRecv = [&](int upOrDown) {
      OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
      Value haloSz = dyn_cast<Value>(v);
      // ...
        haloSz = rewriter.create<arith::ConstantOp>(
            // ...
                cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
      auto hasSize = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::sgt, haloSz, zero);
      rewriter.create<scf::IfOp>(loc, hasSize,
          // ...
            genSendRecv(upOrDown > 0);
            builder.create<scf::YieldOp>(loc);
          // ...
    };
    // ...
    // Restore the full size before handling the next dimension.
    dimSizes[dim] = shape[dim];
    // ...
  }
  // ...
  if (isa<MemRefType>(op.getResult().getType())) {
    // ...
  } else {
    assert(isa<RankedTensorType>(op.getResult().getType()));
    // ... wrap the updated buffer back into a tensor:
        loc, op.getResult().getType(), array,
    // ...
  }
  return success();
}
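
// Illustrative, self-contained sketch (not part of this pass) of the
// communication pattern the generated IR performs for ONE halo'd dimension
// of a 1-D buffer, assuming uniform halo widths across processes. Neighbor
// ranks of -1 mean "at the mesh border, skip". The send-then-receive order
// per phase mirrors the generated code; all names are illustrative only.
// Compile with an MPI compiler (e.g. mpicc).
#include <mpi.h>

void exchangeHalo1D(double *array, int n,          // local buffer incl. halos
                    int lowerHalo, int upperHalo,  // halo widths
                    int downRank, int upRank,      // neighbor ranks or -1
                    MPI_Comm comm) {
  const int tag = 23; // some fixed message tag, as in the lowering
  // Phase 1: send the lower owned region down, receive the upper halo
  // from the up neighbor.
  if (downRank >= 0 && upperHalo > 0)
    MPI_Send(array + lowerHalo, upperHalo, MPI_DOUBLE, downRank, tag, comm);
  if (upRank >= 0 && upperHalo > 0)
    MPI_Recv(array + n - upperHalo, upperHalo, MPI_DOUBLE, upRank, tag, comm,
             MPI_STATUS_IGNORE);
  // Phase 2: send the upper owned region up, receive the lower halo
  // from the down neighbor.
  if (upRank >= 0 && lowerHalo > 0)
    MPI_Send(array + n - upperHalo - lowerHalo, lowerHalo, MPI_DOUBLE, upRank,
             tag, comm);
  if (downRank >= 0 && lowerHalo > 0)
    MPI_Recv(array, lowerHalo, MPI_DOUBLE, downRank, tag, comm,
             MPI_STATUS_IGNORE);
}
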
// ConvertMeshToMPIPass: pass driver wiring up the type converter, the
// conversion target and the patterns above.
struct ConvertMeshToMPIPass
    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
  // ...

  void runOnOperation() override {
    uint64_t worldRank = -1;
    // If a DLTI entry provides the MPI world rank, use it.
    auto dltiAttr =
        dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
    if (succeeded(dltiAttr)) {
      if (!isa<IntegerAttr>(dltiAttr.value())) {
        getOperation()->emitError()
            << "Expected an integer attribute for MPI:comm_world_rank";
        return signalPassFailure();
      }
      worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
    }
    // ...
    // ShardingType is converted into a tuple of three ranked tensors.
        [](ShardingType type,
    // ...
          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
                                        ShapedType::kDynamic};
    // ...
    // Target materialization: unpack the tuple back into its three tensors.
          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
            // ...
          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
          // ...
          for (auto oprnd : castOp.getInputs()) {
            if (!isa<RankedTensorType>(oprnd.getType()))
              // ...
            results.emplace_back(oprnd);
          }
    // ...
    target.addIllegalDialect<mesh::MeshDialect>();
    // ...
    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
                           arith::ArithDialect, tensor::TensorDialect,
                           bufferization::BufferizationDialect,
                           linalg::LinalgDialect, memref::MemRefDialect>();
    // ...
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      // ...
    });
    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
        // ...
    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
                 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
    // ConvertProcessLinearIndexOp takes the optional statically known rank.
    patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
    // ...
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        // ...
  }
};