MLIR  22.0.0git
Partition.h
Go to the documentation of this file.
1 //===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H
10 #define MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H
11 
14 
15 namespace mlir {
16 namespace shard {
17 
18 // Insert resharding partition of the value `sourceShardValue`
19 // from sharding `source` to sharding `target`.
20 // `sourceShardValue` is the already sharded value according to `source`.
21 //
22 // Example
23 //
24 // ```mlir
25 // shard.grid @grid_1d(shape = 2)
26 // ...
27 // %1 = shard.shard %0 to <@grid_1d, [[0]]> : tensor<2xi8>
28 // %2 = shard.shard %1 to <@grid_1d, [[]]> annotate_for_users: tensor<2xi8>
29 // ```
30 //
31 // Will result in
32 //
33 // ```mlir
34 // %1 = shard.all_gather %0 on @grid_1d grid_axes = [0] gather_axis = 0 :
35 // tensor<1xi8> -> tensor<2xi8>
36 // ```
37 TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
38  ShardOp target,
39  TypedValue<ShapedType> sourceShardValue);
40 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
41  ShardOp target,
42  TypedValue<ShapedType> sourceShardValue,
43  SymbolTableCollection &symbolTableCollection);
44 
45 void reshardingRegisterDependentDialects(DialectRegistry &registry);
46 
47 } // namespace shard
48 } // namespace mlir
49 
50 #endif // MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H
shard::GridOp GridOp
TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
Definition: Partition.cpp:505
void reshardingRegisterDependentDialects(DialectRegistry &registry)
Definition: Partition.cpp:526
Include the generated interface declarations.