|
| template<typename SourceAxes, typename TargetAxes> |
| static bool | mlir::shard::arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes) |
| static Sharding | mlir::shard::targetShardingInSplitLastAxis (MLIRContext *ctx, const Sharding &sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | mlir::shard::splitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) |
| static std::optional< std::tuple< int64_t, GridAxis > > | mlir::shard::detectSplitLastAxisInResharding (const Sharding &sourceSharding, const Sharding &targetSharding) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | mlir::shard::trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< int64_t, SmallVector< GridAxis > > > | mlir::shard::detectUnsplitLastAxesInResharding (const Sharding &srcSharding, const Sharding &tgtSharding) |
| static Sharding | mlir::shard::targetShardingInUnsplitLastAxes (MLIRContext *ctx, const Sharding &sourceSharding, int64_t splitTensorDim, size_t numUnsplitAxes) |
| static ShapedType | mlir::shard::allGatherResultTypeInUnsplitLastAxes (ShapedType sourceType, int64_t splitTensorDim, ArrayRef< int64_t > gridShape, ArrayRef< GridAxis > unsplitAxes) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | mlir::shard::unsplitLastAxesInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorDim, ArrayRef< GridAxis > unsplitAxes) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | mlir::shard::tryUnsplitLastAxesInResharding (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< int64_t, int64_t, GridAxis > > | mlir::shard::detectMoveLastSplitAxisInResharding (const Sharding &sourceSharding, const Sharding &targetSharding) |
| static Sharding | mlir::shard::targetShardingInMoveLastAxis (MLIRContext *ctx, const Sharding &sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
| static ShapedType | mlir::shard::allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | mlir::shard::moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | mlir::shard::tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | mlir::shard::tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, const Sharding &targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static TypedValue< ShapedType > | mlir::shard::reshard (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, const Sharding &targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
| TypedValue< ShapedType > | mlir::shard::reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue) |
| TypedValue< ShapedType > | mlir::shard::reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection) |
| void | mlir::shard::reshardingRegisterDependentDialects (DialectRegistry &registry) |
| static SmallVector< Type > | mlir::shard::shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection) |
| static LogicalResult | mlir::shard::partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static std::vector< Sharding > | mlir::shard::getOperandShardings (Operation &op) |
| static std::vector< Sharding > | mlir::shard::getResultShardings (Operation &op) |
| static LogicalResult | mlir::shard::partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | mlir::shard::checkFullyAnnotated (Block &block) |
| static LogicalResult | mlir::shard::checkFullyAnnotated (Operation *op) |
| static LogicalResult | mlir::shard::partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | mlir::shard::partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | mlir::shard::partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection) |