MLIR  19.0.0git
Spmdization.h
Go to the documentation of this file.
1 //===- Spmdization.h - Mesh Spmdization -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
10 #define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
11 
15 
16 namespace mlir {
17 namespace mesh {
18 
19 /// Inserts the resharding spmdization of the value `sourceShardValue`
20 /// from sharding `source` to sharding `target`.
21 /// `sourceShardValue` is the already sharded value according to `source`.
22 ///
23 /// Example
24 ///
25 /// ```mlir
26 /// mesh.mesh @mesh_1d(shape = 2)
27 /// ...
28 /// %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8>
29 /// %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users : tensor<2xi8>
30 /// ```
31 ///
32 /// Will result in
33 ///
34 /// ```mlir
35 /// %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 :
36 /// tensor<1xi8> -> tensor<2xi8>
37 /// ```
///
/// Note: in the result, `%0` stands for the per-device shard passed as
/// `sourceShardValue` (`tensor<1xi8>`), not the unsharded `tensor<2xi8>`
/// value `%0` of the input snippet.
///
/// \param builder the builder used to create the resharding operations.
/// \param mesh the mesh that `source` and `target` refer to (`@mesh_1d` in
///        the example).
/// \param source the sharding that `sourceShardValue` currently has.
/// \param target the sharding to convert to.
/// \param sourceShardValue the value already sharded according to `source`.
/// \returns the value resharded according to `target`.
38 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
39  ShardOp target,
40  TypedValue<ShapedType> sourceShardValue);
/// Same as the `reshard` overload above, but without an explicit `mesh`
/// argument.
/// NOTE(review): presumably the `MeshOp` referenced by the shardings is
/// resolved through `symbolTableCollection`; confirm against the
/// implementation.
41 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
42  ShardOp target,
43  TypedValue<ShapedType> sourceShardValue,
44  SymbolTableCollection &symbolTableCollection);
45 
/// Adds to `registry` the dialects that resharding depends on.
/// NOTE(review): presumably these are the dialects whose operations
/// `reshard` may create (e.g. `mesh.all_gather` in the example above);
/// confirm the exact set in the implementation.
46 void reshardingRegisterDependentDialects(DialectRegistry &registry);
47 
48 } // namespace mesh
49 } // namespace mlir
50 
51 #endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
Include the generated interface declarations.