'shard' Dialect
The ‘shard’ dialect defines a set of attributes, operations, and interfaces for working with tensor sharding and device communication.
It’s inspired by GSPMD (General and Scalable Parallelization for ML Computation Graphs).
Originally, the dialect was called `mesh`, but it was renamed to better reflect what it actually does.
Collective Communication Operations ¶
The ‘shard’ dialect includes several collective operations that help coordinate communication between devices arranged in a grid.
If you’re not already familiar with collective operations, the Wikipedia article on collective communication is a good starting point.
Unlike traditional collectives that are defined in terms of message-passing between explicit buffers on each process, the collectives in this dialect work at a higher level. They’re defined in terms of how data moves across the dimensions of a tensor, and the participating processes are inferred from how the tensor is sharded - not specified manually.
Device Groups ¶
Each collective operation runs within a group of devices. You define groups using the `grid` and `grid_axes` attributes, which describe how to slice the full device grid into smaller groups. Devices that have the same coordinates outside the listed `grid_axes` belong to the same group.
Example: Say your device grid is shaped 2×3×4×5, and you set `grid_axes = [0, 1]`. This splits the grid into groups by fixing axes 2 and 3. You’d get groups like:
{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 }
So the groups are identified by the coordinates `(k, m)`, and devices like `(1, 0, 2, 3)` and `(1, 1, 2, 3)` are in the same group. But `(1, 0, 2, 4)` is in a different group.
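As a hedged IR sketch (the tensor type and the choice of `all_reduce` are illustrative only, not part of the definition above), a collective with `grid_axes = [0, 1]` on such a grid forms one group per `(k, m)` coordinate:
shard.grid @grid4d(shape = 2x3x4x5)
...
// Illustrative: each group of 2*3 = 6 devices sharing the same (k, m) reduces together.
%sum = shard.all_reduce %0 on @grid4d grid_axes = [0, 1] : tensor<4xf32> -> tensor<4xf32>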
For some collectives (like `all-to-all`), the order of devices in the group matters. The device order is based on the order of axes in `grid_axes`, from outermost to innermost.
Example: If `grid_axes = [3, 1]`, then device `(i, 1, k, 0)` comes before `(i, 0, k, 1)` and `(i, 2, k, 0)`.
In-group Devices ¶
Some operations (like `broadcast`, `scatter`, and `send`) refer to a specific device within each group. These in-group devices are identified using their coordinates over the axes listed in `grid_axes`.
Example: In a 3D grid with `grid_axes = [0, 2]`, an in-group device is specified as `(i, j)`. If a group is fixed at coordinate `g` on axis 1, then the full device index would be `(i, g, j)`.
Purity and Execution Model ¶
Collective operations involve all devices in a group (e.g. `all-gather`, `all-to-all`) and are considered pure. Operations like `send` and `recv` are not collective and are not pure.
The execution model assumes SPMD (Single Program, Multiple Data):
- Every process runs the same program.
- At any collective operation, all processes are in sync.
This means compiler optimizations must treat collective ops carefully. For example, if a collective is removed during optimization, it must be removed from every path and every process that would have participated - otherwise, you’ll get undefined behavior at runtime.
Marking these ops as pure also helps with standard compiler passes like dead code elimination and common subexpression elimination. It ensures that when the program is executed, all devices reach the same collective operations at the same point in the program, and so avoids deadlocks.
Operations ¶
shard.all_gather (shard::AllGatherOp) ¶
All-gather over a device grid.
Syntax:
operation ::= `shard.all_gather` $input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis
attr-dict `:` type($input) `->` type($result)
Gathers along the `gather_axis` tensor axis.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
Input:
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
Result:
gather tensor
axis 1
------------>
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
gather_axis | ::mlir::IntegerAttr | index attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
shard.all_reduce (shard::AllReduceOp) ¶
All-reduce over a device grid.
Syntax:
operation ::= `shard.all_reduce` $input `on` $grid (`grid_axes` `=` $grid_axes^)? (`reduction` `=` $reduction^)?
attr-dict `:` type($input) `->` type($result)
The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.
Attributes: `reduction`: Indicates the reduction method.
Example:
%1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max>
: tensor<3x4xf32> -> tensor<3x4xf64>
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultShape
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::shard::ReductionKindAttr | Reduction of an iterator/grid dimension. |
Operands: ¶
Operand | Description |
---|---|
input | memref of any type values or ranked tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | memref of any type values or ranked tensor of any type values |
shard.all_slice (shard::AllSliceOp) ¶
All-slice over a device grid. This is the inverse of all-gather.
Syntax:
operation ::= `shard.all_slice` $input `on` $grid (`grid_axes` `=` $grid_axes^)? `slice_axis` `=` $slice_axis
attr-dict `:` type($input) `->` type($result)
Slice along the `slice_axis` tensor axis.
This operation can be thought of as the inverse of all-gather.
Technically, it is not required that all processes have the same input tensor.
Each process will slice a piece of its local tensor based on its in-group device index.
The operation does not communicate data between devices.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1
: tensor<2x4xi8> -> tensor<2x2xi8>
Input:
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
Result:
slice tensor
axis 1
------------>
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
slice_axis | ::mlir::IntegerAttr | index attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
shard.all_to_all (shard::AllToAllOp) ¶
All-to-all over a device grid.
Syntax:
operation ::= `shard.all_to_all` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`split_axis` `=` $split_axis
`concat_axis` `=` $concat_axis
attr-dict `:` type($input) `->` type($result)
Performs an all-to-all on tensor pieces split along `split_axis`. The resulting pieces are concatenated along `concat_axis` on each device.
Example:
shard.grid @grid0(shape = 3)
...
%1 = shard.all_to_all %0 on @grid0 grid_axes = [0]
split_axis = 0 concat_axis = 0
: tensor<3x2xi8> -> tensor<3x2xi8>
Input:
device device device
(0) (1) (2)
+-------+-------+-------+ | split and concat along
| 11 12 | 21 22 | 31 32 | | tensor axis 0
| 13 14 | 23 24 | 33 34 | ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+
Result:
device device device
(0) (1) (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
split_axis | ::mlir::IntegerAttr | index attribute |
concat_axis | ::mlir::IntegerAttr | index attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
shard.broadcast (shard::BroadcastOp) ¶
Broadcast over a device grid.
Syntax:
operation ::= `shard.broadcast` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Broadcast the tensor on `root` to all devices in each respective group. The operation broadcasts along the grid axes `grid_axes`. The `root` device specifies the in-group multi-index that is broadcast to all other devices in the group.
Example:
shard.grid @grid0(shape = 2x2)
%1 = shard.broadcast %0 on @grid0
grid_axes = [0]
root = [0]
: (tensor<2xi8>) -> tensor<2xi8>
Input:
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
device (1, 0) -> | | | <- device (1, 1)
+-------+-------+
Output:
+-------+-------+
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1)
+-------+-------+
device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1)
+-------+-------+
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
input | ranked tensor of any type values |
root_dynamic | variadic of index |
Results: ¶
Result | Description |
---|---|
result | ranked tensor of any type values |
shard.gather (shard::GatherOp) ¶
Gather over a device grid.
Syntax:
operation ::= `shard.gather` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`gather_axis` `=` $gather_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Gathers on device `root` along the `gather_axis` tensor axis. `root` specifies the coordinates of a device along `grid_axes`. It uniquely identifies the root device for each device group.
The result tensor on non-root devices is undefined. Using it will result in undefined behavior.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.gather %0 on @grid0 grid_axes = [1]
gather_axis = 1 root = [1]
: (tensor<2x2xi8>) -> tensor<2x4xi8>
Input:
gather tensor
axis 1
------------>
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
Result:
+-------------+
| 1 2 5 6 | <- devices (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 1)
| 11 12 15 16 |
+-------------+
Devices `(0, 0)` and `(1, 0)` have an undefined result.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
gather_axis | ::mlir::IntegerAttr | index attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
root_dynamic | variadic of index |
Results: ¶
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
shard.get_sharding (shard::GetShardingOp) ¶
Get the sharding of the given tensor.
Syntax:
operation ::= `shard.get_sharding` $source attr-dict `:` type($source) `->` type($result)
This operation returns the sharding of the given tensor as a Sharding.
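For illustration, a minimal usage sketch (the value `%t` and its tensor type are assumptions, not taken from the op definition):
%sharding = shard.get_sharding %t : tensor<4x8xf32> -> !shard.sharding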
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands: ¶
Operand | Description |
---|---|
source | ranked tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | sharding definition |
shard.grid (shard::GridOp) ¶
Description of a device/process grid.
Syntax:
operation ::= `shard.grid` $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
attr-dict
The shard.grid operation is a symbol operation that identifies a specific grid. The operation has two attributes:
- `sym_name`: This attribute uniquely identifies the name of the grid. This name serves as a symbolic reference to the grid throughout the MLIR module, allowing for consistent referencing and easier debugging.
- `shape`: This attribute represents the shape of the device grid. It uses the same notation as a tensor shape and also allows dynamic dimensions. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be determined during compile time. For example, `2x?x4`.
Example:
// A device grid with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
shard.grid @grid0(shape = 4x8x12)
// A device grid with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
shard.grid @grid1(shape = 4x?)
// A device grid with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
shard.grid @grid2(shape = ?x4)
// A device grid with 2 axes, the number of devices along both axes
// is unknown
shard.grid @grid3(shape = ?x?)
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), Symbol
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
sym_name | ::mlir::StringAttr | string attribute |
shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
shard.grid_shape (shard::GridShapeOp) ¶
Get the shape of the grid.
Syntax:
operation ::= `shard.grid_shape` $grid (`axes` `=` $axes^)?
attr-dict `:` type($result)
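A hedged usage sketch (the grid below is assumed; one `index` result is produced per requested axis):
shard.grid @grid0(shape = 4x8x12)
...
// All axis sizes: 4, 8 and 12.
%s:3 = shard.grid_shape @grid0 : index, index, index
// Only the size of axis 1, i.e. 8.
%s1 = shard.grid_shape @grid0 axes = [1] : index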
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
Results: ¶
Result | Description |
---|---|
result | variadic of index |
shard.neighbors_linear_indices (shard::NeighborsLinearIndicesOp) ¶
For a given grid index, get the linear indices of the direct neighbor processes along the given split.
Syntax:
operation ::= `shard.neighbors_linear_indices` `on` $grid `[` $device `]`
`split_axes` `=` $split_axes
attr-dict `:` type(results)
Example:
shard.grid @grid0(shape = 10x20x30)
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c2, %c3] split_axes = [1] : index, index
The above returns two indices, 633 and 693, which correspond to the index of the previous process (1, 1, 3) and the next process (1, 3, 3) along split axis 1.
A negative value is returned if there is no neighbor in the respective direction along the given `split_axes`.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
split_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
device | variadic of index |
Results: ¶
Result | Description |
---|---|
neighbor_down | index |
neighbor_up | index |
shard.process_linear_index (shard::ProcessLinearIndexOp) ¶
Get the linear index of the current device.
Syntax:
operation ::= `shard.process_linear_index` `on` $grid attr-dict `:` type($result)
Example:
%idx = shard.process_linear_index on @grid : index
If `@grid` has shape `(10, 20, 30)`, a device with multi-index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
Results: ¶
Result | Description |
---|---|
result | index |
shard.process_multi_index (shard::ProcessMultiIndexOp) ¶
Get the multi-index of the current device along the specified grid axes.
Syntax:
operation ::= `shard.process_multi_index` `on` $grid (`axes` `=` $axes^)?
attr-dict `:` type($result)
It is used in the SPMD format of the IR. The `axes` must be non-negative and less than the total number of grid axes. If `axes` is empty, the index along all grid axes is returned.
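A hedged usage sketch (the grid is assumed; one `index` result is produced per requested axis):
shard.grid @grid0(shape = 2x2x4)
...
// Multi-index of the current device over all grid axes.
%i:3 = shard.process_multi_index on @grid0 : index, index, index
// Multi-index restricted to grid axis 2.
%i2 = shard.process_multi_index on @grid0 axes = [2] : index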
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
Results: ¶
Result | Description |
---|---|
result | variadic of index |
shard.recv (shard::RecvOp) ¶
Receive over a device grid.
Syntax:
operation ::= `shard.recv` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
attr-dict `:` functional-type(operands, results)
Receive from a device within a device group.
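A hedged usage sketch (the grid, tensor type, and source coordinates are assumptions for illustration):
shard.grid @grid0(shape = 2x2)
...
// Receive the tensor from the in-group device with coordinate 1 along grid axis 0.
%received = shard.recv %0 on @grid0 grid_axes = [0] source = [1]
  : (tensor<2x2xi8>) -> tensor<2x2xi8>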
Interfaces: OpAsmOpInterface, SymbolUserOpInterface
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
source | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
source_dynamic | variadic of index |
Results: ¶
Result | Description |
---|---|
result | ranked tensor of any type values |
shard.reduce (shard::ReduceOp) ¶
Reduce over a device grid.
Syntax:
operation ::= `shard.reduce` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`reduction` `=` $reduction^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Reduces on device `root` within each device group. `root` specifies the coordinates of a device along `grid_axes`. It uniquely identifies the root device within its device group.
The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.
Attributes: `reduction`: Indicates the reduction method.
Example:
%1 = shard.reduce %0 on @grid0 grid_axes = [1, 0]
reduction = <max> root = [2, 3]
: (tensor<3x4xf32>) -> tensor<3x4xf64>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::shard::ReductionKindAttr | Reduction of an iterator/grid dimension. |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
input | ranked tensor of any type values |
root_dynamic | variadic of index |
Results: ¶
Result | Description |
---|---|
result | ranked tensor of any type values |
shard.reduce_scatter (shard::ReduceScatterOp) ¶
Reduce-scatter over a device grid.
Syntax:
operation ::= `shard.reduce_scatter` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`reduction` `=` $reduction^)?
`scatter_axis` `=` $scatter_axis
attr-dict `:` type($input) `->` type($result)
After the reduction, the result is scattered within each device group. The tensor is split along `scatter_axis` and the pieces are distributed across the device group.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
reduction = <max> scatter_axis = 0
: tensor<3x4xf32> -> tensor<1x4xf64>
Input:
device
(0, 1)
↓
+-------+-------+ | scatter tensor
device (0, 0) -> | 1 2 | 5 6 | | axis 0
| 3 4 | 7 8 | ↓
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 |
| 11 12 | 15 16 |
+-------+-------+
↑
device
(1, 1)
Result:
+-------+
| 6 8 | <- devices (0, 0)
+-------+
| 10 12 | <- devices (0, 1)
+-------+
| 22 24 | <- devices (1, 0)
+-------+
| 26 28 | <- devices (1, 1)
+-------+
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::shard::ReductionKindAttr | Reduction of an iterator/grid dimension. |
scatter_axis | ::mlir::IntegerAttr | index attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | ranked tensor of any type values |
shard.scatter (shard::ScatterOp) ¶
Scatter over a device grid.
Syntax:
operation ::= `shard.scatter` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`scatter_axis` `=` $scatter_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
For each device group, split the input tensor on the `root` device along axis `scatter_axis` and scatter the parts across the group devices.
Example:
shard.grid @grid0(shape = 2x2)
%1 = shard.scatter %0 on @grid0 grid_axes = [0]
scatter_axis = 0
root = [1]
: (tensor<2x2xi8>) -> tensor<1x2xi8>
Input:
device
(0, 1)
↓
+-------+-------+ | scatter tensor
device (0, 0) -> | | | | axis 0
| | | ↓
+-------+-------+
device (1, 0) -> | 1 2 | 5 6 |
| 3 4 | 7 8 |
+-------+-------+
↑
device
(1, 1)
Result:
device
(0, 1)
↓
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 |
+-------+-------+
device (1, 0) -> | 3 4 | 7 8 |
+-------+-------+
↑
device
(1, 1)
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
scatter_axis | ::mlir::IntegerAttr | index attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
root_dynamic | variadic of index |
Results: ¶
Result | Description |
---|---|
result | ranked tensor of any type values |
shard.send (shard::SendOp) ¶
Send over a device grid.
Syntax:
operation ::= `shard.send` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
attr-dict `:` functional-type(operands, results)
Send from one device to another within a device group.
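A hedged usage sketch (the grid, tensor type, and destination coordinates are assumptions for illustration):
shard.grid @grid0(shape = 2x2)
...
// Send the tensor to the in-group device with coordinate 0 along grid axis 0.
%sent = shard.send %0 on @grid0 grid_axes = [0] destination = [0]
  : (tensor<2x2xi8>) -> tensor<2x2xi8>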
Interfaces: OpAsmOpInterface, SymbolUserOpInterface
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
destination | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
destination_dynamic | variadic of index |
Results: ¶
Result | Description |
---|---|
result | ranked tensor of any type values |
shard.shard (shard::ShardOp) ¶
Annotate how a tensor is sharded across a device grid.
Syntax:
operation ::= `shard.shard` $src `to` $sharding
(`annotate_for_users` $annotate_for_users^)?
attr-dict `:` type($result)
The shard.shard operation is designed to specify and guide the sharding behavior of a tensor value across a grid topology. This operation has two operands and one optional attribute:
- `input`: This operand represents the tensor value that needs to be annotated for sharding.
- `sharding`: This operand is of type `ShardingType`, which is the core data structure to represent the distribution of a tensor on a device grid. It is typically defined by a `shard.sharding` operation.
- `annotate_for_users`: A unit attribute addressing the scenario when a tensor’s sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.
Example:
func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
...
}
func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
...
}
func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
%1 = shard.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
...
}
// The first shard.shard op applies to %arg0, the second shard.shard op
// applies to the operand of op0, the third shard.shard op applies to the
// operand of op1
func.func @both_result_and_multi_operands_annotated(
%arg0 : tensor<4x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
%sharding1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%1 = shard.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%sharding2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
%2 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
"op0"(%1) : ...
"op1"(%2) : ...
...
}
The following usages are undefined:
func.func @annotate_on_same_result_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
%1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_result_same_value_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
%1 = shard.shard %arg0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_operand_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%1 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
...
}
func.func @result_annotated_after_operand(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
annotate_for_users | ::mlir::UnitAttr | unit attribute |
Operands: ¶
Operand | Description |
---|---|
src | ranked tensor of any type values |
sharding | sharding definition |
Results: ¶
Result | Description |
---|---|
result | ranked tensor of any type values |
shard.shard_shape (shard::ShardShapeOp) ¶
Get the shard shape for a given process/device.
Syntax:
operation ::= `shard.shard_shape` `dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
`sharding` `=` $sharding
`device` `=` custom<DynamicIndexList>($device_dynamic, $device)
attr-dict `:` type(results)
The device/process id is a multi-index of the device/process in the grid. This operation might be used during partition when the shard shape depends on (non-constant) values used in `shard.sharding`.
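A hedged sketch of the intended use (the sharding, the 12x16 tensor dimensions, and the device index are assumptions for illustration):
shard.grid @grid1d_4(shape = 4)
%sharding = shard.sharding @grid1d_4 split_axes = [[0]] : !shard.sharding
// Illustrative: shape of the shard of a 12x16 tensor held by device 2 of @grid1d_4.
%shp:2 = shard.shard_shape dims = [12, 16] sharding = %sharding device = [2] : index, index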
Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
dims | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
device | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
dims_dynamic | variadic of index |
sharding | sharding definition |
device_dynamic | variadic of index |
Results: ¶
Result | Description |
---|---|
result | variadic of index |
shard.sharding (shard::ShardingOp) ¶
Define a sharding of a tensor.
Syntax:
operation ::= `shard.sharding` $grid
`split_axes` `=` $split_axes
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
(`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
attr-dict `:` type($result)
The Sharding specifies how a tensor is sharded and distributed across the process grid. It is typically used in a `shard.shard` operation.
The operation has the following attributes and operands:
- `grid`: this attribute is a FlatSymbolRefAttr that refers to the device grid where the distributed tensor is placed. The symbol must resolve to a `shard.grid` operation.
- `split_axes`: is an array composed of int64_t sub-arrays. The outer array's maximum size is the `rank` of the related tensor. For the i-th sub-array, if its value is [x, y], it indicates that the tensor's i-th dimension is split along the x and y axes of the device grid.
- [Optional] Sizes of halos to be added for each sharded tensor dimension. `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each sharded dimension. `halo_sizes = [1, 2]` means that the first sharded dimension gets an additional halo of size 1 at the start of the first dimension and a halo of size 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2 sharded dimensions, e.g. the first sharded dimension gets `[1, 2]` halos and the second gets `[2, 3]` halos. `?` indicates dynamic halo sizes.
- [Optional] Offsets for each shard and sharded tensor dimension. `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each sharded tensor dimension, the offsets (starting index) of all shards in that dimension and an additional value for the end of the last shard are provided. For a 1d sharding this means that position `i` holds the exclusive prefix sum for shard `i`, and since only contiguous sharding is supported, its inclusive prefix sum is at position `i+1`.
Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded, `sharded_dims_offsets = [0, 24, 32, 0, 20, 32]` means that the first device of the device grid will get a shard of shape 24x20x32 and the second device will get a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
`halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
Examples:
shard.grid @grid0(shape = 2x2x4)
shard.grid @grid1d_4(shape = 4)
// The tensor is fully replicated on @grid0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
%sharding0 = shard.sharding @grid0 split_axes = [[]]
// The tensor is sharded on the first dimension along axis 0 of @grid0
%sharding1 = shard.sharding @grid0 split_axes = [[0]]
// Could be used for a shard.shard op
%sharded0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
// The tensor is sharded on its first dimension along axis 0 of @grid0
// and it has halo-sizes of 1 and 2 on the sharded dim.
%halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
%sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>
// The tensor is sharded on its second dimension along axis 0 of @grid1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
%sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
%sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
split_axes | ::mlir::shard::GridAxesArrayAttr | |
static_sharded_dims_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
static_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
dynamic_sharded_dims_offsets | variadic of 64-bit signless integer |
dynamic_halo_sizes | variadic of 64-bit signless integer |
Results: ¶
Result | Description |
---|---|
result | sharding definition |
shard.shift (shard::ShiftOp) ¶
Shift over a device grid.
Syntax:
operation ::= `shard.shift` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`shift_axis` `=` $shift_axis
`offset` `=` $offset
(`rotate` $rotate^)?
attr-dict `:` type($input) `->` type($result)
Within each device group, shift along grid axis `shift_axis` by an offset `offset`. The result on devices that do not have a corresponding source is undefined. `shift_axis` must be one of `grid_axes`. If the `rotate` attribute is present, a rotation is performed instead of a shift.
Example:
shard.grid @grid0(shape = 2x4)
%1 = shard.shift %0 on @grid0 grid_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
Input:
grid axis 1
----------->
+----+----+----+----+
| 1 | 2 | 3 | 4 |
+----+----+----+----+
| 5 | 6 | 7 | 8 |
+----+----+----+----+
Result:
+----+----+----+----+
| 3 | 4 | 1 | 2 |
+----+----+----+----+
| 7 | 8 | 5 | 6 |
+----+----+----+----+
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultShape
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
shift_axis | ::mlir::IntegerAttr | index attribute |
offset | ::mlir::IntegerAttr | 64-bit signless integer attribute |
rotate | ::mlir::UnitAttr | unit attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | ranked tensor of any type values |
shard.update_halo (shard::UpdateHaloOp) ¶
Update halo data.
Syntax:
operation ::= `shard.update_halo` $destination
`on` $grid
`split_axes` `=` $split_axes
(`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
attr-dict `:` type($result)
This operation updates the halo regions of shards, e.g. when their sharding specifies halos and the actual tensor/memref data might have changed on the remote devices. Changes might be caused by mutating operations and/or if the new halo regions are larger than the existing ones.
The destination is supposed to be initialized with the local data (not the halos). It is assumed that all devices hold tensors with same-sized halo data, as specified by `source_halo_sizes`/`static_source_halo_sizes` and `destination_halo_sizes`/`static_destination_halo_sizes` in the source and destination/result shardings.
`split_axes` specifies for each tensor axis along which grid axes its halo data is updated.
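A hedged usage sketch (the grid, halo sizes, and the local shard type are assumptions for illustration):
shard.grid @grid0(shape = 4)
...
// Illustrative: refresh halos of size 1 on both sides of dimension 0 of the local shard.
%updated = shard.update_halo %dest on @grid0 split_axes = [[0]]
  halo_sizes = [1, 1] : tensor<10x12xf32>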
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, DestinationStyleOpInterface, NoMemoryEffect (MemoryEffectOpInterface), SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
split_axes | ::mlir::shard::GridAxesArrayAttr | |
static_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
destination | non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values |
halo_sizes | variadic of 64-bit signless integer |
Results: ¶
Result | Description |
---|---|
result | non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values |
Attributes ¶
GridAxesArrayAttr ¶
Syntax:
#shard.axisarray<
::llvm::ArrayRef<GridAxesAttr> # axes
>
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
axes | ::llvm::ArrayRef<GridAxesAttr> |
ReductionKindAttr ¶
Reduction of an iterator/grid dimension.
Syntax:
#shard.partial<
::mlir::shard::ReductionKind # value
>
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::shard::ReductionKind | an enum of type ReductionKind |