|
template<typename SourceAxes , typename TargetAxes > |
static bool | mlir::mesh::arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes) |
|
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > | mlir::mesh::handlePartialAxesDuringResharding (OpBuilder &builder, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard) |
|
static MeshShardingAttr | mlir::mesh::targetShardingInSplitLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis) |
|
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > | mlir::mesh::splitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) |
|
static std::optional< std::tuple< int64_t, MeshAxis > > | mlir::mesh::detectSplitLastAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding) |
|
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > | mlir::mesh::trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard) |
|
static std::optional< std::tuple< int64_t, MeshAxis > > | mlir::mesh::detectUnsplitLastAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding) |
|
static MeshShardingAttr | mlir::mesh::targetShardingInUnsplitLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis) |
|
static ShapedType | mlir::mesh::allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) |
|
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > | mlir::mesh::unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) |
|
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > | mlir::mesh::tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
|
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > | mlir::mesh::detectMoveLastSplitAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding) |
|
static MeshShardingAttr | mlir::mesh::targetShardingInMoveLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
|
static ShapedType | mlir::mesh::allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
|
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > | mlir::mesh::moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis) |
|
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > | mlir::mesh::tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
|
static TypedValue< ShapedType > | mlir::mesh::reshardOn1DMesh (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
|
TypedValue< ShapedType > | mlir::mesh::reshard (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
|
TypedValue< ShapedType > | mlir::mesh::reshard (OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue) |
|
TypedValue< ShapedType > | mlir::mesh::reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection) |
|
void | mlir::mesh::reshardingRegisterDependentDialects (DialectRegistry &registry) |
|
SmallVector< Type > | mlir::mesh::shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection) |
|
static LogicalResult | mlir::mesh::spmdizeOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
|
static SmallVector< MeshShardingAttr > | mlir::mesh::getOperandShardings (Operation &op) |
|
static SmallVector< MeshShardingAttr > | mlir::mesh::getResultShardings (Operation &op) |
|
static LogicalResult | mlir::mesh::spmdizeOperation (ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
|
static LogicalResult | mlir::mesh::spmdizeOperation (Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
|
static LogicalResult | mlir::mesh::spmdizeBlock (Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
|
static LogicalResult | mlir::mesh::spmdizeFuncOp (FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection) |
|