Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1, 0]]` on a `2x3` mesh.
unsharded `2x3` tensor, sharded on a `2x3` mesh

sharding = `[[0, 1]]`
mesh contents:

```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
+----+----+----+             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
```
Transform into sharding = `[[1, 0]]`
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 13 | 22 |             |
+----+----+----+             |
| 12 | 21 | 23 |             |
+----+----+----+             ↓
```
Algorithm: Swap contents on devices that have the same linear index in the 2 shardings.
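As a sanity check, here is a small Python sketch (illustrative only; `linear_index` is an ad-hoc helper, not dialect API) that computes each device's swap peer for `[[0, 1]] -> [[1, 0]]` on the `2x3` mesh:

```python
# Which devices exchange shards when swapping mesh axes [0, 1] -> [1, 0]?
# linear_index: the device's linear index when the mesh axes are
# traversed in the given order.
def linear_index(coords, mesh_shape, axis_order):
    idx = 0
    for a in axis_order:
        idx = idx * mesh_shape[a] + coords[a]
    return idx

mesh_shape = (2, 3)
for m in range(mesh_shape[0]):
    for n in range(mesh_shape[1]):
        idx = linear_index((m, n), mesh_shape, (0, 1))
        # Peer: the device whose linear index under order (1, 0) is idx.
        peer = next((p, q)
                    for p in range(mesh_shape[0])
                    for q in range(mesh_shape[1])
                    if linear_index((p, q), mesh_shape, (1, 0)) == idx)
        print(f"shard with linear index {idx}: ({m}, {n}) -> {peer}")
```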
Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` mesh.
unsharded `2x3` tensor, sharded on a `2x3` mesh

sharding = `[[0, 1]]`
mesh contents:

```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
+----+----+----+             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
```
Transform into sharding = `[[1]]`
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
```
Algorithm: All-gather along mesh axis 0.
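A minimal NumPy sketch of what this all-gather does to the per-device contents, assuming the layout in the diagrams above (an illustration, not the dialect's implementation):

```python
import numpy as np

# Per-device contents for sharding [[0, 1]] on the 2x3 mesh: device
# (m, n) holds a single element (see the diagram above).
shards = {(m, n): np.array([[10 * (m + 1) + n + 1]])
          for m in range(2) for n in range(3)}

# All-gather along mesh axis 0: devices that differ only in their mesh
# axis 0 coordinate concatenate their shards along tensor axis 0, so
# afterwards the contents depend only on mesh axis 1 (sharding [[1]]).
gathered = {(m, n): np.concatenate([shards[(p, n)] for p in range(2)])
            for m in range(2) for n in range(3)}
print(gathered[(0, 0)].ravel())  # [11 21], as in the diagram
```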
Reshard `2x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` mesh.
unsharded `2x6` tensor

```
11 12 13 14 15 16
21 22 23 24 25 26
```

sharded on a `2x3` mesh

sharding = `[[], [0, 1]]`

mesh contents:
```
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             |
| 14 | 15 | 16 |             |
| 24 | 25 | 26 |             |
+----+----+----+             ↓
```
Transform into sharding = `[[], [0]]`
```
mesh axis 1
----------->
+----------+----------+ mesh axis 0 |
| 11 12 13 | 11 12 13 |             |
| 21 22 23 | 21 22 23 |             |
+----------+----------+             |
| 14 15 16 | 14 15 16 |             |
| 24 25 26 | 24 25 26 |             |
+----------+----------+             ↓
```
Algorithm: All-gather along mesh axis 1.
Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` mesh.
unsharded `4x8` tensor

```
11 12 13 14 15 16 17 18
21 22 23 24 25 26 27 28
31 32 33 34 35 36 37 38
41 42 43 44 45 46 47 48
```

sharded on a `2x2x2` mesh

sharding = `[[0], [1, 2]]`

mesh contents:
```
mesh axis 2
----------->
+-------+-------+ mesh axis 1 | mesh axis 0 |
| 11 12 | 13 14 |             |             |
| 21 22 | 23 24 |             |             |
+-------+-------+             |             |
| 15 16 | 17 18 |             |             |
| 25 26 | 27 28 |             |             |
+-------+-------+             ↓             |
+-------+-------+                           |
| 31 32 | 33 34 |                           |
| 41 42 | 43 44 |                           |
+-------+-------+                           |
| 35 36 | 37 38 |                           |
| 45 46 | 47 48 |                           |
+-------+-------+                           ↓
```
Transform into sharding = `[[0], [2]]`
```
mesh axis 2
----------->
+-------------+-------------+ mesh axis 1 | mesh axis 0 |
| 11 12 13 14 | 15 16 17 18 |             |             |
| 21 22 23 24 | 25 26 27 28 |             |             |
+-------------+-------------+             |             |
| 11 12 13 14 | 15 16 17 18 |             |             |
| 21 22 23 24 | 25 26 27 28 |             |             |
+-------------+-------------+             ↓             |
+-------------+-------------+                           |
| 31 32 33 34 | 35 36 37 38 |                           |
| 41 42 43 44 | 45 46 47 48 |                           |
+-------------+-------------+                           |
| 31 32 33 34 | 35 36 37 38 |                           |
| 41 42 43 44 | 45 46 47 48 |                           |
+-------------+-------------+                           ↓
```
Algorithm:

This can't be done with just an all-gather along mesh axis 1: mesh axis 1 is the outer (most-significant) axis in the sharding of tensor axis 1, so gathering along it would concatenate non-adjacent column blocks (e.g. `11 12` with `15 16`). It can be handled by a sequence of resharding transformations: `[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]`.
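A small NumPy sketch (ad-hoc helpers, illustrative only) that tracks which columns of the `4x8` tensor a device holds under each sharding of tensor axis 1, showing why the direct gather is off and the axis swap fixes it:

```python
import numpy as np

MESH = (2, 2, 2)
cols = np.arange(8)  # 0-based column indices of the 4x8 tensor

def shard_index(device, axes):
    # Linearized shard index along tensor axis 1, with `axes` the mesh
    # axes sharding that tensor axis, outermost first.
    idx = 0
    for a in axes:
        idx = idx * MESH[a] + device[a]
    return idx

def columns_held(device, axes):
    n_shards = int(np.prod([MESH[a] for a in axes], dtype=int))
    size = len(cols) // n_shards
    b = shard_index(device, axes)
    return cols[b * size:(b + 1) * size]

dev = (0, 1, 0)  # device at mesh coordinates (axis 0, axis 1, axis 2)
print(columns_held(dev, (1, 2)))  # [[0], [1, 2]]: cols [4 5]
print(columns_held(dev, (2, 1)))  # [[0], [2, 1]]: cols [2 3]
print(columns_held(dev, (2,)))    # [[0], [2]]:    cols [0 1 2 3]
# All-gather along mesh axis 1 would merge this device's cols [4 5] with
# cols [0 1] from its group peer: non-contiguous, so not [[0], [2]].
```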
Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.
unsharded `6x6` tensor

```
11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66
```

sharded on a `2x3` mesh

sharding = `[[0], [1]]`
```
mesh axis 1
----------->
+-------+-------+-------+ mesh axis 0 |
| 11 12 | 13 14 | 15 16 |             |
| 21 22 | 23 24 | 25 26 |             |
| 31 32 | 33 34 | 35 36 |             |
+-------+-------+-------+             |
| 41 42 | 43 44 | 45 46 |             |
| 51 52 | 53 54 | 55 56 |             |
| 61 62 | 63 64 | 65 66 |             |
+-------+-------+-------+             ↓
```
transform to sharding = `[[1], [0]]`
```
mesh axis 1
----------->
+----------+----------+----------+ mesh axis 0 |
| 11 12 13 | 31 32 33 | 51 52 53 |             |
| 21 22 23 | 41 42 43 | 61 62 63 |             |
+----------+----------+----------+             |
| 14 15 16 | 34 35 36 | 54 55 56 |             |
| 24 25 26 | 44 45 46 | 64 65 66 |             |
+----------+----------+----------+             ↓

mesh axis 0
----------->
+----------+----------+ mesh axis 1 |
| 11 12 13 | 14 15 16 |             |
| 21 22 23 | 24 25 26 |             |
+----------+----------+             |
| 31 32 33 | 34 35 36 |             |
| 41 42 43 | 44 45 46 |             |
+----------+----------+             |
| 51 52 53 | 54 55 56 |             |
| 61 62 63 | 64 65 66 |             |
+----------+----------+             ↓
```
Algorithm: TODO
Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` mesh.
unsharded `6x6` tensor

```
11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66
```

sharded on a `2x6` mesh

sharding = `[[0], [1]]`
```
mesh axis 1
----------->
+----+----+----+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
| 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
+----+----+----+----+----+----+             |
| 41 | 42 | 43 ‖ 44 | 45 | 46 |             |
| 51 | 52 | 53 ‖ 54 | 55 | 56 |             |
| 61 | 62 | 63 ‖ 64 | 65 | 66 |             |
+----+----+----+----+----+----+             ↓
```
transform to sharding = `[[1], [0]]`
```
mesh axis 0
----------->
+----------+----------+ mesh axis 1 |
| 11 12 13 | 14 15 16 |             |
+----------+----------+             |
| 21 22 23 | 24 25 26 |             |
+----------+----------+             |
| 31 32 33 | 34 35 36 |             |
+==========+==========+             |
| 41 42 43 | 44 45 46 |             |
+----------+----------+             |
| 51 52 53 | 54 55 56 |             |
+----------+----------+             |
| 61 62 63 | 64 65 66 |             |
+----------+----------+             ↓
```
Algorithm: TODO
Reshard `KxL` tensor from `[[0], [1]]` to `[[1], [0]]` on an `MxN` mesh.
`MxN` mesh. `KxL` tensor `t`. `d(m, n)` is the shard of `t` on device `(m, n)`.
sharding = `[[0], [1]]`

The tensor shard `s` on each device has size `(K ceildiv M, L ceildiv N)`.
```
d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
```

Substituting

```
i <- m * (K ceildiv M) + k
j <- n * (L ceildiv N) + l

m -> i floordiv (K ceildiv M)
n -> j floordiv (L ceildiv N)
k -> i % (K ceildiv M)
l -> j % (L ceildiv N)
```

for the inverse map we get

```
t[i, j] -> d(
    i floordiv (K ceildiv M), j floordiv (L ceildiv N)
)[
    i % (K ceildiv M), j % (L ceildiv N)
]
```
Check:

```
i = 13, j = 17, M = 3, N = 4, K = 16, L = 23

t[13, 17] = d(
    13 floordiv (16 ceildiv 3),
    17 floordiv (23 ceildiv 4)
)[
    13 % (16 ceildiv 3),
    17 % (23 ceildiv 4)
]
= d(
    13 floordiv 6,
    17 floordiv 6
)[
    13 % 6,
    17 % 6
]
= d(2, 2)[1, 5]
= t[
    2 * (16 ceildiv 3) + 1,
    2 * (23 ceildiv 4) + 5
]
= t[
    2 * 6 + 1,
    2 * 6 + 5
]
= t[13, 17]
```
sharding = `[[1], [0]]`

The tensor shard `s` on each device has size `(K ceildiv N, L ceildiv M)`.
```
d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
```

Substituting

```
i <- n * (K ceildiv N) + k
j <- m * (L ceildiv M) + l

m -> j floordiv (L ceildiv M)
n -> i floordiv (K ceildiv N)
k -> i % (K ceildiv N)
l -> j % (L ceildiv M)
```

for the inverse map we get

```
t[i, j] -> d(
    j floordiv (L ceildiv M), i floordiv (K ceildiv N)
)[
    i % (K ceildiv N), j % (L ceildiv M)
]
```
Check:

```
i = 9, j = 19, M = 5, N = 2, K = 27, L = 14

t[9, 19] = d(
    19 floordiv (14 ceildiv 5),
    9 floordiv (27 ceildiv 2)
)[
    9 % (27 ceildiv 2),
    19 % (14 ceildiv 5)
]
= d(
    19 floordiv 3,
    9 floordiv 14
)[
    9 % 14,
    19 % 3
]
= d(6, 0)[9, 1]
= t[
    0 * (27 ceildiv 2) + 9,
    6 * (14 ceildiv 5) + 1
]
= t[
    0 * 14 + 9,
    6 * 3 + 1
]
= t[9, 19]
```
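The two maps can also be checked exhaustively rather than point-wise. An illustrative Python sketch (with an ad-hoc `ceildiv`) that round-trips every element through the forward and inverse maps for both shardings, using the parameters from the checks above:

```python
# Exhaustive round-trip of the forward/inverse maps above.
def ceildiv(a, b):
    return -(-a // b)

def check(K, L, M, N):
    # sharding [[0], [1]]: i -> (m, k) along mesh axis 0, j -> (n, l)
    # along mesh axis 1; the forward map must reconstruct (i, j).
    sk, sl = ceildiv(K, M), ceildiv(L, N)
    for i in range(K):
        for j in range(L):
            m, k = i // sk, i % sk
            n, l = j // sl, j % sl
            assert m < M and n < N
            assert (m * sk + k, n * sl + l) == (i, j)
    # sharding [[1], [0]]: tensor axis 0 is split along mesh axis 1 and
    # tensor axis 1 along mesh axis 0.
    sk, sl = ceildiv(K, N), ceildiv(L, M)
    for i in range(K):
        for j in range(L):
            n, k = i // sk, i % sk
            m, l = j // sl, j % sl
            assert m < M and n < N
            assert (n * sk + k, m * sl + l) == (i, j)

check(K=16, L=23, M=3, N=4)  # parameters of the first check
check(K=27, L=14, M=5, N=2)  # parameters of the second check
print("both maps round-trip")
```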
sharding = `[[0], [1]]`

```
d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
```

sharding = `[[1], [0]]`

```
d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
```
sharding `[[0], [1]] -> [[1], [0]]`

`d1(m, n)` is the tensor shard on device `(m, n)` for sharding `[[0], [1]]`. `d2(m, n)` is the tensor shard on device `(m, n)` for sharding `[[1], [0]]`.
```
d1(m, n)[k, l] ->
t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
d2(
    (n * (L ceildiv N) + l) floordiv (L ceildiv M),
    (m * (K ceildiv M) + k) floordiv (K ceildiv N)
)[
    (m * (K ceildiv M) + k) % (K ceildiv N),
    (n * (L ceildiv N) + l) % (L ceildiv M)
]
= d2(p, q)[u, v]
```
We want to copy the data between devices in slices/tiles. What are the source/target tile coordinates? For a fixed `(m, n, p, q)`, what is the range of `(k, l, u, v)`? TODO
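Pending a closed form, the tile structure can be explored by brute force. An illustrative Python sketch that, for a `6x6` tensor on the `2x3` mesh, groups elements by their `(d1, d2)` device pair and prints the `(k, l)` ranges exchanged (the `(u, v)` ranges follow the same pattern):

```python
from collections import defaultdict

K, L, M, N = 6, 6, 2, 3
sk0, sl0 = -(-K // M), -(-L // N)  # [[0], [1]] shard sizes
sk1, sl1 = -(-K // N), -(-L // M)  # [[1], [0]] shard sizes

# For every element, compute its source device/offset (m, n)[k, l] and
# its target device/offset (p, q)[u, v], then group by device pair.
tiles = defaultdict(list)
for i in range(K):
    for j in range(L):
        m, n, k, l = i // sk0, j // sl0, i % sk0, j % sl0
        p, q, u, v = j // sl1, i // sk1, i % sk1, j % sl1
        tiles[(m, n), (p, q)].append((k, l, u, v))

for (src, dst), elems in sorted(tiles.items()):
    ks = sorted({k for k, l, u, v in elems})
    ls = sorted({l for k, l, u, v in elems})
    print(f"d1{src} -> d2{dst}: k in {ks}, l in {ls}")
```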
Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.
Device placement on a `2x3` mesh

```
11 12 13 <- devices
21 22 23
```
sharding `[[0], [1]]`

```
tensor axis 1
----------->
+----+----+----+ tensor axis 0 |
| 11 | 12 | 13 |               |
+----+----+----+               |
| 21 | 22 | 23 |               |
+----+----+----+               ↓
```
transform to sharding `[[1], [0]]`

```
tensor axis 1
----------->
+----+----+ tensor axis 0 |
| 11 | 21 |               |
+----+----+               |
| 12 | 22 |               |
+----+----+               |
| 13 | 23 |               |
+----+----+               ↓
```
Source cuts over the tensor for sharding `[[0], [1]]` (cell label = source device):

```
+-----------------+--------+--------+-----------------+
|                 |                 |                 |
+                 +                 +                 +
|        11       |        12       |        13       |
+                 +                 +                 +
|                 |                 |                 |
+-----------------+--------+--------+-----------------+
|                 |                 |                 |
+                 +                 +                 +
|        21       |        22       |        23       |
+                 +                 +                 +
|                 |                 |                 |
+-----------------+--------+--------+-----------------+
```

Target cuts for sharding `[[1], [0]]` (cell label = target device):

```
+-----------------+--------+--------+-----------------+
|                          |                          |
+            11            +            21            +
|                          |                          |
+-----------------+--------+--------+-----------------+
|                          |                          |
+            12            +            22            +
|                          |                          |
+-----------------+--------+--------+-----------------+
|                          |                          |
+            13            +            23            +
|                          |                          |
+-----------------+--------+--------+-----------------+
```

Overlay of both cut patterns (cell label = source device, target device):

```
+-----------------+--------+--------+-----------------+
|                 |        |        |                 |
+      11 11      + 12 11  + 12 21  +      13 21      +
|                 |        |        |                 |
+-----------------+--------+--------+-----------------+
|      11 12      | 12 12  | 12 22  |      13 22      |
+-----------------+--------+--------+-----------------+
|      21 12      | 22 12  | 22 22  |      23 22      |
+-----------------+--------+--------+-----------------+
|                 |        |        |                 |
+      21 13      + 22 13  + 22 23  +      23 23      +
|                 |        |        |                 |
+-----------------+--------+--------+-----------------+
```
If `S` and `T` are the source and target shard sizes along some tensor axis, then the combined cut pattern repeats with a period of `lcm(S, T) = (S*T)/gcd(S, T)`. TODO
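For example (an illustrative Python check of the period claim):

```python
from math import gcd

# Cut points along one tensor axis with source shard size S and target
# shard size T; the combined pattern repeats every lcm(S, T).
S, T, length = 3, 2, 12
period = S * T // gcd(S, T)
cuts = sorted(set(range(0, length, S)) | set(range(0, length, T)))
print("period:", period)   # 6
print("cuts:", cuts)       # [0, 2, 3, 4, 6, 8, 9, 10]
```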
Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` mesh.
unsharded `6x6` tensor

```
11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66
```

sharded on a `3` mesh

sharding = `[[0], []]`
```
+-------------------+ mesh axis 0 |
| 11 12 13 14 15 16 |            |
| 21 22 23 24 25 26 |            |
+-------------------+            |
| 31 32 33 34 35 36 |            |
| 41 42 43 44 45 46 |            |
+-------------------+            |
| 51 52 53 54 55 56 |            |
| 61 62 63 64 65 66 |            |
+-------------------+            ↓
```
transform to sharding = `[[], [0]]`
```
mesh axis 0
----------->
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
| 41 42 | 43 44 | 45 46 |
| 51 52 | 53 54 | 55 56 |
| 61 62 | 63 64 | 65 66 |
+-------+-------+-------+
```
Algorithm:

```mlir
%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
```
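A NumPy model of this all-to-all (illustrative only; it mimics the `split_axis`/`concat_axis` semantics on the three shards from the diagrams above):

```python
import numpy as np

# The three devices' 2x6 shards from the diagram above.
group = [np.array([[11, 12, 13, 14, 15, 16],
                   [21, 22, 23, 24, 25, 26]]) + 20 * d for d in range(3)]

# all_to_all: each device splits its shard into len(group) pieces along
# split_axis = 1, sends piece d to device d, and concatenates the pieces
# it receives along concat_axis = 0.
pieces = [np.split(shard, len(group), axis=1) for shard in group]
result = [np.concatenate([pieces[src][dst] for src in range(len(group))],
                         axis=0) for dst in range(len(group))]
print(result[0])  # the first 6x2 shard of the target sharding
```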
Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` mesh.
unsharded `4x4` tensor

```
11 12 13 14
21 22 23 24
31 32 33 34
41 42 43 44
```

sharded on a `2x2x2` mesh

sharding = `[[0], [1, 2]]`
```
mesh axis 2
----------->
+----+----+ mesh axis 1 | mesh axis 0 |
| 11 | 12 |             |             |
| 21 | 22 |             |             |
+----+----+             |             |
| 13 | 14 |             |             |
| 23 | 24 |             |             |
+----+----+             ↓             |
+----+----+                           |
| 31 | 32 |                           |
| 41 | 42 |                           |
+----+----+                           |
| 33 | 34 |                           |
| 43 | 44 |                           |
+----+----+                           ↓
```
transform to sharding = `[[0, 1], [2]]`
```
mesh axis 2
----------->
+-------+-------+ mesh axis 1 | mesh axis 0 |
| 11 12 | 13 14 |             |             |
+-------+-------+             |             |
| 21 22 | 23 24 |             |             |
+-------+-------+             ↓             |
+-------+-------+                           |
| 31 32 | 33 34 |                           |
+-------+-------+                           |
| 41 42 | 43 44 |                           |
+-------+-------+                           ↓
```
Algorithm:

The all-to-all

```mlir
%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
```

is not enough. The resharding can be decomposed into `[[0], [1, 2]] -> [[0], [2, 1]] -> [[0, 1], [2]]`.
Decomposition into basis of reshardings

We can decompose each resharding into a sequence of basis reshardings. This is not communication-efficient in terms of minimizing the data communicated between devices; an efficient approach would be more complicated to implement. Each device has to receive at most as much data as the size of its shard under the target sharding.
Basis:

- From replicate to split.
  ```
  [[]] -> [[1]]
  ```
  Extract slices without communication.

- From split to replicate.
  ```
  [[0]] -> [[]]
  [[0, 1]] -> [[1]]
  ```
  All-gather along mesh axis 0.

- Swap the order of mesh axes assigned to the same tensor axis.
  ```
  [[0, 1]] -> [[1, 0]]
  ```
  Swap contents on devices with the same linear index.

- Move a mesh axis to a different tensor dimension.
  ```
  [[0], []] -> [[], [0]]
  ```
  All-to-all.
Example decomposition of `[[0], [1]] -> [[1], [0]]` into

```
[[0], [1]] -> all-gather along mesh axis 1 ->
[[0], []]  -> all-to-all along mesh axis 0 ->
[[], [0]]  -> extract slice along mesh axis 1 ->
[[1], [0]]
```
Example decomposition of `[[3, 2], [], [0, 1]] -> [[0], [1, 2], []]` into

```
[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
[[3, 2], [1], [0]]   -> all-to-all along mesh axis 2 ->
[[3], [1, 2], [0]]   -> all-gather along mesh axis 3 ->
[[], [1, 2], [0]]    -> all-to-all along mesh axis 0 ->
[[0], [1, 2], []]
```
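These rewrites are mechanical enough to check symbolically. An illustrative Python sketch (ad-hoc helper names) that replays the second decomposition on the sharding lists:

```python
# A sharding is a list of mesh-axis lists, one per tensor dimension.
def all_gather(sharding, mesh_axis):
    # Drop mesh_axis from the innermost position of its tensor dimension.
    return [dims[:-1] if dims and dims[-1] == mesh_axis else dims
            for dims in sharding]

def all_to_all(sharding, mesh_axis, dst_dim):
    # Move mesh_axis to the innermost position of tensor dim dst_dim.
    out = [[a for a in dims if a != mesh_axis] for dims in sharding]
    out[dst_dim].append(mesh_axis)
    return out

s = [[3, 2], [], [0, 1]]
s = all_to_all(s, 1, 1)  # -> [[3, 2], [1], [0]]
s = all_to_all(s, 2, 1)  # -> [[3], [1, 2], [0]]
s = all_gather(s, 3)     # -> [[], [1, 2], [0]]
s = all_to_all(s, 0, 0)  # -> [[0], [1, 2], []]
assert s == [[0], [1, 2], []]
print(s)
```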