29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Casting.h"
39template <
typename SourceAxes,
typename TargetAxes>
41 const TargetAxes &targetAxes) {
42 return llvm::all_of(targetAxes, [&sourceAxes](
auto &targetAxis) {
43 return sourceAxes.contains(targetAxis);
53 while (
static_cast<int64_t>(targetShardingSplitAxes.size()) <=
57 auto targetSplitAxes =
58 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
59 targetSplitAxes.push_back(splitGridAxis);
60 targetShardingSplitAxes[splitTensorAxis] =
68static std::tuple<TypedValue<ShapedType>,
Sharding>
74 AllSliceOp::create(builder, sourceShard, grid,
78 builder.
getContext(), std::move(sourceSharding), splitTensorAxis,
80 return {targetShard, targetSharding};
88static std::optional<std::tuple<int64_t, GridAxis>>
91 for (
size_t tensorAxis = 0; tensorAxis < targetSharding.
getSplitAxes().size();
94 if (sourceSharding.
getSplitAxes()[tensorAxis].size() + 1 !=
104 targetSharding.
getSplitAxes()[tensorAxis].asArrayRef().end() -
109 if (targetSharding.
getSplitAxes()[tensorAxis].size() != 1) {
113 return std::make_tuple(
115 targetSharding.
getSplitAxes()[tensorAxis].asArrayRef().back());
120static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
126 sourceSharding, std::move(targetSharding))) {
127 auto [tensorAxis, gridAxis] = detectRes.value();
129 tensorAxis, gridAxis);
140static std::optional<std::tuple<int64_t, SmallVector<GridAxis>>>
145 for (
size_t tensorDim = 0; tensorDim < srcSize; ++tensorDim) {
146 auto srcSplitAxes = srcSharding.
getSplitAxes()[tensorDim].asArrayRef();
148 auto tgtSplitAxes = tgtSharding.
getSplitAxes()[tensorDim].asArrayRef();
151 if (srcSplitAxes.size() <= tgtSplitAxes.size())
155 if (!std::equal(tgtSplitAxes.begin(), tgtSplitAxes.end(),
156 srcSplitAxes.begin()))
158 dimOff = tgtSplitAxes.size();
162 if (srcSplitAxes.size() == 0)
170 return std::make_tuple(tensorDim, unsplitAxes);
179 size_t numUnsplitAxes) {
182 assert(
static_cast<int64_t>(resSplitAxes.size()) > splitTensorDim);
184 assert(srcSplitAxes.size() >= numUnsplitAxes);
185 size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes;
187 srcSplitAxes.begin() + numSplitAxes);
198 for (
GridAxis gridAxis : unsplitAxes)
199 targetShape[splitTensorDim] =
201 return sourceType.cloneWith(targetShape, sourceType.getElementType());
214 ctx, std::move(sourceSharding), splitTensorDim, unsplitAxes.size());
216 sourceShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes);
217 Value allGatherResult = AllGatherOp::create(
219 RankedTensorType::get(allGatherResultType.getShape(),
220 allGatherResultType.getElementType()),
221 grid.getSymName(), unsplitAxes, sourceShard, APInt(64, splitTensorDim));
222 ShapedType targetType =
225 tensor::CastOp::create(builder, targetType, allGatherResult).getResult();
226 return {targetShard, targetSharding};
229static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
233 ShapedType sourceUnshardedShape,
236 sourceSharding, std::move(targetSharding))) {
237 auto [tensorDim, gridAxes] = detectRes.value();
239 sourceUnshardedShape, sourceShard, grid,
240 tensorDim, gridAxes);
251static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
254 for (
size_t sourceTensorAxis = 0;
255 sourceTensorAxis < sourceSharding.
getSplitAxes().size();
256 ++sourceTensorAxis) {
257 for (
size_t targetTensorAxis = 0;
258 targetTensorAxis < targetSharding.
getSplitAxes().size();
259 ++targetTensorAxis) {
260 if (sourceTensorAxis == targetTensorAxis)
262 if (sourceSharding.
getSplitAxes()[sourceTensorAxis].empty() ||
263 targetSharding.
getSplitAxes()[targetTensorAxis].empty() ||
264 sourceSharding.
getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
270 llvm::make_range(sourceSharding.
getSplitAxes()[sourceTensorAxis]
277 llvm::make_range(targetSharding.
getSplitAxes()[targetTensorAxis]
285 return std::make_tuple(
286 sourceTensorAxis, targetTensorAxis,
287 sourceSharding.
getSplitAxes()[sourceTensorAxis].asArrayRef().back());
299 while (
static_cast<int64_t>(targetShardingSplitAxes.size()) <=
304 auto sourceSplitAxes =
305 llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
306 assert(!sourceSplitAxes.empty());
307 auto gridAxis = sourceSplitAxes.back();
308 sourceSplitAxes.pop_back();
309 targetShardingSplitAxes[sourceTensorAxis] =
312 auto targetSplitAxes =
313 llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
314 targetSplitAxes.push_back(gridAxis);
315 targetShardingSplitAxes[targetTensorAxis] =
326 targetShape[sourceTensorAxis] =
328 targetShape[targetTensorAxis] =
330 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
333static std::tuple<TypedValue<ShapedType>, Sharding>
336 ShapedType sourceUnshardedShape,
344 ctx, std::move(sourceSharding), sourceTensorAxis, targetTensorAxis);
346 sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
348 Value allToAllResult = AllToAllOp::create(
350 RankedTensorType::get(allToAllResultShape.getShape(),
351 allToAllResultShape.getElementType()),
353 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
354 ShapedType targetShape =
357 tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
358 return {targetShard, targetSharding};
361static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
365 ShapedType sourceUnshardedShape,
368 sourceSharding, std::move(targetSharding))) {
369 auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
371 builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
372 sourceTensorAxis, targetTensorAxis, gridAxis);
382static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
386 ShapedType sourceUnshardedShape,
399 assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
400 assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
401 ShapedType::isStaticShape(tgtHaloSizes) &&
402 sourceShard.getType().hasStaticShape()) &&
403 "dynamic shapes/halos are not supported yet for shard-partition");
404 auto rank = sourceShard.getType().getRank();
407 strides(rank, 1), outShape(sourceShard.getType().getShape()),
408 coreShape(sourceShard.getType().getShape());
412 for (
auto i = 0u; i < rank; ++i) {
413 if (i < splitAxes.size() && !splitAxes[i].empty()) {
414 if (!srcHaloSizes.empty()) {
415 coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
416 srcCoreOffs[i] = srcHaloSizes[i * 2];
418 tgtCoreOffs[i] = tgtHaloSizes[i * 2];
420 coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
427 tensor::EmptyOp::create(builder, sourceShard.
getLoc(), outShape,
428 sourceShard.getType().getElementType());
429 auto core = tensor::ExtractSliceOp::create(
430 builder, sourceShard.
getLoc(),
431 RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
432 sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
433 auto initOprnd = tensor::InsertSliceOp::create(
434 builder, sourceShard.
getLoc(), core, initVal, noVals, noVals, noVals,
435 tgtCoreOffs, coreShape, strides);
438 auto updateHaloResult =
439 UpdateHaloOp::create(
440 builder, sourceShard.
getLoc(),
441 RankedTensorType::get(outShape,
442 sourceShard.getType().getElementType()),
443 initOprnd, grid.getSymName(),
469 builder, grid, sourceSharding, targetSharding,
470 sourceUnshardedValue.getType(), sourceShard)) {
471 return std::get<0>(tryRes.value());
474 assert(sourceShard.getType() ==
475 shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
476 [[maybe_unused]] ShapedType targetShardType =
478 assert(sourceShard.getType().getRank() == targetShardType.getRank());
487 builder, grid, sourceSharding, targetSharding,
488 sourceUnshardedValue.getType(), sourceShard)) {
489 std::tie(targetShard, actualTargetSharding) = tryRes.value();
490 }
else if (
auto tryRes =
492 targetSharding, sourceShard)) {
493 std::tie(targetShard, actualTargetSharding) = tryRes.value();
495 builder, grid, sourceSharding, targetSharding,
496 sourceUnshardedValue.getType(), sourceShard)) {
497 std::tie(targetShard, actualTargetSharding) = tryRes.value();
501 assert(targetShard &&
"Did not find any pattern to apply.");
502 assert(actualTargetSharding == targetSharding);
503 assert(targetShard.getType() == targetShardType);
510 assert(source.getResult() ==
target.getSrc());
511 auto sourceSharding = source.getSharding();
512 auto targetSharding =
target.getSharding();
514 return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
515 source.getSrc(), sourceShardValue);
522 GridOp srcGrid =
getGrid(source, symbolTableCollection);
523 assert(srcGrid && srcGrid ==
getGrid(
target, symbolTableCollection));
524 return reshard(builder, srcGrid, source,
target, sourceShardValue);
528 registry.
insert<shard::ShardDialect, tensor::TensorDialect>();
531#define GEN_PASS_DEF_PARTITION
532#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
546 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
547 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
548 return arg.getType();
551 assert(rankedTensorArg.hasOneUse());
553 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
555 GridOp grid =
getGrid(shardOp, symbolTableCollection);
557 shardOp.getSharding()));
568 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
569 if (!shardingInterface) {
573 resultShardings, partitionMap,
574 symbolTableCollection, builder);
576 if (failed(shardingInterface.partition(
577 partitionedOperands, operandShardings, resultShardings,
578 partitionMap, symbolTableCollection, builder))) {
584 return partitionMap.contains(result);
593 std::vector<Sharding> res;
595 llvm::transform(op.
getOperands(), std::back_inserter(res), [](
Value operand) {
596 TypedValue<RankedTensorType> rankedTensor =
597 dyn_cast<TypedValue<RankedTensorType>>(operand);
598 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
604 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
605 return Sharding(shardOp.getSharding());
613 std::vector<Sharding> res;
617 if (!result.hasOneUse() || result.use_empty()) {
626 ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
628 return Sharding(shardOp.getSharding());
630 if (rankedTensor.getType().getRank() == 0) {
635 if (
auto sharding = operand.getDefiningOp<ShardingOp>()) {
636 return Sharding(sharding.getGridAttr());
649 Value targetPartitionValue;
653 ShardOp srcShardOp = shardOp.getSrc().
getDefiningOp<ShardOp>();
655 targetPartitionValue = partitionMap.
lookup(shardOp.getSrc());
659 cast<TypedValue<ShapedType>>(partitionMap.
lookup(srcShardOp));
660 targetPartitionValue =
reshard(builder, srcShardOp, shardOp,
661 srcPartitionValue, symbolTableCollection);
664 assert(!partitionMap.
contains(shardOp.getResult()));
665 partitionMap.
map(shardOp.getResult(), targetPartitionValue);
675 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
676 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
679 if (rankedTensorArg.getNumUses() > 1)
681 <<
"Cannot partition: expected a single use for block argument "
682 << arg.getArgNumber() <<
" in block "
685 auto shardOp = dyn_cast<ShardOp>(useOp);
688 <<
"Cannot partition: expected a shard.shard op for block "
689 <<
"argument " << arg.getArgNumber() <<
" in block "
710 auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
711 if (!rankedTT || rankedTT.getRank() == 0)
714 auto shard = operand.get().getDefiningOp<ShardOp>();
716 return op->
emitError() <<
"Cannot partition: tensor operand "
717 << operand.getOperandNumber()
718 <<
" must be defined by a shard.shard operation.";
719 if (!
shard.getAnnotateForUsers())
721 <<
"Cannot partition: shard.shard for operand "
722 << operand.getOperandNumber() <<
" must set 'annotate_for_users'.";
727 <<
"Cannot partition: result " <<
result.getResultNumber()
728 <<
" must have exactly one use.";
729 auto shard = dyn_cast<ShardOp>(*
result.user_begin());
732 <<
"Cannot partition: user of result " <<
result.getResultNumber()
733 <<
" must be shard.shard operation.";
734 if (
shard.getAnnotateForUsers())
735 return op->
emitError() <<
"Cannot partition: shard.shard for result "
736 <<
result.getResultNumber()
737 <<
" must not set 'annotate_for_users'.";
746 if (isa<ShardingOp>(op)) {
750 if (
auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
751 auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
753 return op.
emitError(
"expected a shard op as source of get_sharding");
755 auto newSharding = builder.
clone(*shardOp.getSharding().getDefiningOp());
756 partitionMap.
map(op.
getResult(0), newSharding->getResult(0));
760 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
771 llvm::transform(op.
getOperands(), std::back_inserter(partitionedOperands),
772 [&partitionMap](
Value operand) {
773 assert(partitionMap.contains(operand));
774 return partitionMap.lookup(operand);
778 symbolTableCollection, builder);
790 llvm::transform(block.
getArguments(), std::back_inserter(argLocations),
795 for (
auto [unshardedBlockArg, partitionedBlockArg] :
797 partitionMap.
map(unshardedBlockArg, partitionedBlockArg);
820 for (
Block &
b : op.getBlocks()) {
821 if (llvm::any_of(
b.getOperations(),
822 [](
Operation &op) { return isa<ShardOp>(op); })) {
823 originalBlocks.push_back(&
b);
827 for (
Block *block : originalBlocks) {
828 if (failed(
partitionBlock(*block, partitionMap, symbolTableCollection,
834 for (
Block *block : originalBlocks) {
841 for (
Block &block : op.getFunctionBody()) {
847 returnOp = &block.back();
852 op.setType(FunctionType::get(
853 op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
862struct Partition :
public impl::PartitionBase<Partition> {
863 void runOnOperation()
override {
866 if (failed(partitionFuncOp(getOperation(), partitionMap,
867 symbolTableCollection))) {
868 return signalPassFailure();
872 void getDependentDialects(DialectRegistry ®istry)
const override {
874 registry.insert<shard::ShardDialect>();
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
BlockArgListType getArguments()
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc() const
Accessors for the implied location.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Location getLoc()
Return a location for this region.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
::mlir::FlatSymbolRefAttr getGridAttr() const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
static std::tuple< TypedValue< ShapedType >, Sharding > unsplitLastAxesInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorDim, ArrayRef< GridAxis > unsplitAxes)
static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx, const Sharding &sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::tuple< TypedValue< ShapedType >, Sharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
static std::tuple< TypedValue< ShapedType >, Sharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis)
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static LogicalResult partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
static ShapedType allGatherResultTypeInUnsplitLastAxes(ShapedType sourceType, int64_t splitTensorDim, ArrayRef< int64_t > gridShape, ArrayRef< GridAxis > unsplitAxes)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUnsplitLastAxesInResharding(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static LogicalResult checkFullyAnnotated(Block &block)
static std::optional< std::tuple< int64_t, SmallVector< GridAxis > > > detectUnsplitLastAxesInResharding(const Sharding &srcSharding, const Sharding &tgtSharding)
static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx, const Sharding &sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
bool isFullReplication(Sharding sharding)
static LogicalResult partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::optional< std::tuple< int64_t, int64_t, GridAxis > > detectMoveLastSplitAxisInResharding(const Sharding &sourceSharding, const Sharding &targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, const Sharding &targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::vector< Sharding > getOperandShardings(Operation &op)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard)
DenseMap< Value, Value > UnshardedToShardedValueMap
static std::vector< Sharding > getResultShardings(Operation &op)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
static Sharding targetShardingInUnsplitLastAxes(MLIRContext *ctx, const Sharding &sourceSharding, int64_t splitTensorDim, size_t numUnsplitAxes)
TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
void reshardingRegisterDependentDialects(DialectRegistry ®istry)
static std::optional< std::tuple< int64_t, GridAxis > > detectSplitLastAxisInResharding(const Sharding &sourceSharding, const Sharding &targetSharding)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This trait indicates that a terminator operation is "return-like".