'shard' Dialect
The ‘shard’ dialect defines a set of attributes, operations, and interfaces for working with tensor sharding and device communication.
It’s inspired by GSPMD (General and Scalable Parallelization for ML Computation Graphs).
Originally, the dialect was called mesh, but it was renamed to better reflect
what it actually does.
Collective Communication Operations
The ‘shard’ dialect includes several collective operations that help coordinate communication between devices arranged in a grid.
If you’re not already familiar with collective operations, this Wikipedia article is a good starting point.
Unlike traditional collectives that are defined in terms of message-passing between explicit buffers on each process, the collectives in this dialect work at a higher level. They’re defined in terms of how data moves across the dimensions of a tensor, and the participating processes are inferred from how the tensor is sharded, not specified manually.
Device Groups
Collective operations run within groups of devices, which are defined
using the grid and grid_axes attributes. These describe
how the full device grid is sliced into smaller groups.
Devices that have the same coordinates outside the listed grid_axes belong
to the same group.
Example: Say your device grid is shaped 2×3×4×5, and you set
grid_axes = [0, 1]. This splits the grid into groups by fixing axes 2 and 3. You’d get groups like:
{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 }
So the groups are identified by the coordinates (k, m), and devices like
(1, 0, 2, 3) and (1, 1, 2, 3) are in the same group. But (1, 0, 2, 4)
is in a different group.
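The grouping rule can be checked with a small Python simulation. This is an illustrative sketch, not part of the dialect or of any MLIR API; the helper names `group_key` and `device_groups` are hypothetical. Each device is keyed by its coordinates on the axes not listed in grid_axes, and devices with equal keys form one group.

```python
from itertools import product


def group_key(device, grid_axes):
    """Coordinates of `device` on the axes NOT listed in grid_axes.

    Two devices belong to the same group exactly when their keys are equal.
    """
    return tuple(c for axis, c in enumerate(device) if axis not in grid_axes)


def device_groups(grid_shape, grid_axes):
    """Map each group key to the list of devices in that group."""
    groups = {}
    for device in product(*(range(n) for n in grid_shape)):
        groups.setdefault(group_key(device, grid_axes), []).append(device)
    return groups


groups = device_groups((2, 3, 4, 5), grid_axes=[0, 1])
print(len(groups))          # 20: one group per (k, m)
print(len(groups[(2, 3)]))  # 6: every (i, j) with 0 <= i < 2, 0 <= j < 3
```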
For some collectives (like all-to-all), the order of devices in the group
matters. The device order is based on the order of axes in grid_axes, from
outermost to innermost.
Example: If grid_axes = [3, 1], then device (i, 1, k, 0) comes before
(i, 0, k, 1) and (i, 2, k, 0).
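The ordering rule amounts to a row-major linearization over the listed axes, with the first entry of grid_axes outermost. In the hypothetical helper below, the coordinates on the other axes (i and k in the example) do not affect the order, so they are fixed to 0 for concreteness.

```python
def in_group_index(device, grid_shape, grid_axes):
    """Position of `device` within its group.

    Coordinates on grid_axes are read in the listed order, first axis
    outermost, i.e. a row-major linearization over the sub-grid.
    """
    index = 0
    for axis in grid_axes:
        index = index * grid_shape[axis] + device[axis]
    return index


shape = (2, 3, 4, 5)
a = in_group_index((0, 1, 0, 0), shape, [3, 1])  # (axis 3, axis 1) = (0, 1) -> 1
b = in_group_index((0, 0, 0, 1), shape, [3, 1])  # (1, 0) -> 1 * 3 + 0 = 3
c = in_group_index((0, 2, 0, 0), shape, [3, 1])  # (0, 2) -> 2
print(a < b, a < c)  # True True
```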
In-group Devices
Some operations (like broadcast, scatter, and send) refer to a specific
device within each group. These in-group devices are identified using their
coordinates over the axes listed in grid_axes.
Example: In a 3D grid with grid_axes = [0, 2], an in-group device is specified
as (i, j). If a group is fixed at coordinate g on axis 1, then the full
device index would be (i, g, j).
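Going from an in-group multi-index plus the group's fixed coordinates back to a full device index can be sketched as follows. The function name and the convention that `group_coords` lists the fixed coordinates in increasing axis order are assumptions of this sketch.

```python
def full_device_index(in_group, group_coords, grid_rank, grid_axes):
    """Rebuild a full grid coordinate.

    in_group: coordinates on grid_axes, in the order grid_axes lists them.
    group_coords: coordinates on the remaining axes, in increasing axis order.
    """
    device = [None] * grid_rank
    for axis, coord in zip(grid_axes, in_group):
        device[axis] = coord
    rest = iter(group_coords)
    for axis in range(grid_rank):
        if device[axis] is None:
            device[axis] = next(rest)
    return tuple(device)


# 3-D grid, grid_axes = [0, 2], in-group device (i, j) = (1, 3), group fixed at g = 7.
print(full_device_index((1, 3), (7,), grid_rank=3, grid_axes=[0, 2]))  # (1, 7, 3)
```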
Purity and Execution Model
Collective operations involve all devices in a group (e.g. all-gather,
all-to-all) and are considered pure. Operations like send and recv are not
collective and are not pure.
The execution model assumes SPMD (Single Program, Multiple Data):
- Every process runs the same program.
- At any collective operation, all processes are in sync.
This means compiler optimizations must treat collective ops carefully. For example, if a collective is removed during optimization, it must be removed from every path and every process that would have participated; otherwise, you’ll get undefined behavior at runtime.
Marking these ops as pure also helps with standard compiler passes like dead code elimination and common subexpression elimination. Together with the SPMD assumption, it ensures that at run time all devices reach each collective at the same point in the program, which avoids deadlocks.
Operations
shard.all_gather (shard::AllGatherOp)
All-gather over a device grid.
Syntax:
operation ::= `shard.all_gather` $input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis
attr-dict `:` type($input) `->` type($result)
Concatenates all tensor slices from a device group defined by grid_axes along
the tensor dimension gather_axis and replicates the result across all devices
in the group.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
Input:
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
Result:
gather tensor
axis 1
------------>
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
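The example can be reproduced with a small Python simulation of the op's semantics. It is a sketch under simplifying assumptions (2-D local tensors stored as nested lists, one dict entry per device); `concat` and `all_gather` are hypothetical helpers, not MLIR APIs.

```python
def concat(tensors, axis):
    """Concatenate 2-D nested lists along axis 0 (rows) or axis 1 (columns)."""
    if axis == 0:
        return [row for t in tensors for row in t]
    return [[x for t in tensors for x in t[r]] for r in range(len(tensors[0]))]


def all_gather(local, grid_axes, gather_axis):
    """Simulate shard.all_gather; `local` maps device coordinates to tensors."""
    result = {}
    for device in local:
        # Group members share the coordinates off grid_axes; order them
        # by their grid_axes coordinates (first listed axis outermost).
        members = sorted(
            (d for d in local
             if all(d[a] == device[a] for a in range(len(d)) if a not in grid_axes)),
            key=lambda d: tuple(d[a] for a in grid_axes))
        result[device] = concat([local[d] for d in members], gather_axis)
    return result


local = {
    (0, 0): [[1, 2], [3, 4]], (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]], (1, 1): [[13, 14], [15, 16]],
}
out = all_gather(local, grid_axes=[1], gather_axis=1)
print(out[(0, 0)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
```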
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
gather_axis | ::mlir::IntegerAttr | index attribute |
Operands:
| Operand | Description |
|---|---|
input | non-0-ranked.tensor of any type values |
Results:
| Result | Description |
|---|---|
result | non-0-ranked.tensor of any type values |
shard.all_reduce (shard::AllReduceOp)
All-reduce over a device grid.
Syntax:
operation ::= `shard.all_reduce` $input `on` $grid (`grid_axes` `=` $grid_axes^)? (`reduction` `=` $reduction^)?
attr-dict `:` type($input) `->` type($result)
Reduces the input tensor across all devices within the groups defined by
grid_axes, using the specified reduction method. The operation performs an
element-wise reduction over the tensor slices from all devices in each group.
Each device in a group receives a replicated copy of the reduction result.
The accumulation element type is determined by the result type and does not
need to match the input element type. Before performing the reduction, each
input element is converted to the result element type.
Attributes:
reduction: Indicates the reduction method.
Example:
%1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max>
: tensor<3x4xf32> -> tensor<3x4xf64>
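A minimal Python sketch of these semantics follows. It assumes 1-D local tensors held as lists, models the conversion to the result element type with `float` (standing in for the f32 to f64 conversion in the example), and only knows three reduction names; the `REDUCTIONS` dictionary and its keys are assumptions of the sketch, not the dialect's full list of reduction kinds.

```python
from functools import reduce

# Reduction names used by this sketch only.
REDUCTIONS = {"sum": lambda x, y: x + y, "max": max, "min": min}


def all_reduce(local, grid_axes, reduction="sum", convert=float):
    """Simulate shard.all_reduce on 1-D local tensors."""
    op = REDUCTIONS[reduction]
    result = {}
    for device in local:
        members = [d for d in local
                   if all(d[a] == device[a] for a in range(len(d)) if a not in grid_axes)]
        # Convert every input element to the accumulation type first.
        tensors = [[convert(x) for x in local[d]] for d in members]
        # Element-wise reduction; every member receives the same result.
        result[device] = [reduce(op, column) for column in zip(*tensors)]
    return result


local = {(0, 0): [1, 8], (0, 1): [5, 2], (1, 0): [3, 4], (1, 1): [7, 6]}
print(all_reduce(local, grid_axes=[1, 0], reduction="max")[(0, 0)])  # [7.0, 8.0]
print(all_reduce(local, grid_axes=[1], reduction="sum")[(1, 0)])     # [10.0, 10.0]
```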
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultShape
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::shard::ReductionKindAttr | Reduction of an iterator/grid dimension. |
Operands:
| Operand | Description |
|---|---|
input | memref of any type values or ranked tensor of any type values |
Results:
| Result | Description |
|---|---|
result | memref of any type values or ranked tensor of any type values |
shard.all_slice (shard::AllSliceOp)
All-slice over a device grid.
Syntax:
operation ::= `shard.all_slice` $input `on` $grid (`grid_axes` `=` $grid_axes^)? `slice_axis` `=` $slice_axis
attr-dict `:` type($input) `->` type($result)
Within each device group defined by grid_axes, slices the input tensor along
the slice_axis dimension. It can be viewed as the inverse of an all-gather if
the input data is replicated along the slice_axis.
Each process simply crops its local data to the slice corresponding to its
in-group device index.
Note: AllSliceOp does not involve any communication between devices, and
devices within a group need not hold replicated input data.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1
: tensor<2x4xi8> -> tensor<2x2xi8>
Input:
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
Result:
slice tensor
axis 1
------------>
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
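Because all_slice only crops local data, it can be simulated without any exchange between devices. The sketch below assumes 2-D nested-list tensors and an extent along slice_axis that divides evenly by the group size; `all_slice` here is a hypothetical Python helper.

```python
def all_slice(local, grid_shape, grid_axes, slice_axis):
    """Simulate shard.all_slice: each device keeps the slice for its in-group index."""
    group_size = 1
    for axis in grid_axes:
        group_size *= grid_shape[axis]
    result = {}
    for device, tensor in local.items():
        # In-group index: row-major over grid_axes, first listed axis outermost.
        idx = 0
        for axis in grid_axes:
            idx = idx * grid_shape[axis] + device[axis]
        extent = len(tensor) if slice_axis == 0 else len(tensor[0])
        size = extent // group_size  # assumes an even split
        lo, hi = idx * size, (idx + 1) * size
        result[device] = tensor[lo:hi] if slice_axis == 0 else [row[lo:hi] for row in tensor]
    return result


top = [[1, 2, 5, 6], [3, 4, 7, 8]]
bottom = [[9, 10, 13, 14], [11, 12, 15, 16]]
local = {(0, 0): top, (0, 1): top, (1, 0): bottom, (1, 1): bottom}
out = all_slice(local, grid_shape=(2, 2), grid_axes=[1], slice_axis=1)
print(out[(0, 1)])  # [[5, 6], [7, 8]]
```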
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
slice_axis | ::mlir::IntegerAttr | index attribute |
Operands:
| Operand | Description |
|---|---|
input | non-0-ranked.tensor of any type values |
Results:
| Result | Description |
|---|---|
result | non-0-ranked.tensor of any type values |
shard.all_to_all (shard::AllToAllOp)
All-to-all over a device grid.
Syntax:
operation ::= `shard.all_to_all` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`split_axis` `=` $split_axis
`concat_axis` `=` $concat_axis
attr-dict `:` type($input) `->` type($result)
Each participant logically splits its input along split_axis,
then scatters the resulting pieces across the group defined by grid_axes.
After receiving data pieces from other participants’ scatters,
it concatenates them along concat_axis to produce the final result.
Example:
shard.grid @grid0(shape = 3)
...
%1 = shard.all_to_all %0 on @grid0 grid_axes = [0]
split_axis = 0 concat_axis = 0
: tensor<3x2xi8> -> tensor<3x2xi8>
Input:
device device device
(0) (1) (2)
+-------+-------+-------+ | split and concat along
| 11 12 | 21 22 | 31 32 | | tensor axis 0
| 13 14 | 23 24 | 33 34 | ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+
Result:
device device device
(0) (1) (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
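The exchange pattern can be sketched in Python for a single group, with local tensors listed by in-group index. This simulation assumes 2-D nested lists and even splits; `split`, `concat`, and `all_to_all` are hypothetical helpers. Piece `dst` of participant `src` lands at position `src` in participant `dst`'s concatenation.

```python
def split(tensor, axis, parts):
    """Split a 2-D nested list into `parts` equal chunks along `axis`."""
    extent = len(tensor) if axis == 0 else len(tensor[0])
    size = extent // parts  # assumes an even split
    if axis == 0:
        return [tensor[i * size:(i + 1) * size] for i in range(parts)]
    return [[row[i * size:(i + 1) * size] for row in tensor] for i in range(parts)]


def concat(tensors, axis):
    """Concatenate 2-D nested lists along axis 0 or 1."""
    if axis == 0:
        return [row for t in tensors for row in t]
    return [[x for t in tensors for x in t[r]] for r in range(len(tensors[0]))]


def all_to_all(local, split_axis, concat_axis):
    """Simulate shard.all_to_all for one group; local[i] is participant i's input."""
    n = len(local)
    pieces = [split(t, split_axis, n) for t in local]  # pieces[src][dst]
    return [concat([pieces[src][dst] for src in range(n)], concat_axis)
            for dst in range(n)]


local = [
    [[11, 12], [13, 14], [15, 16]],  # device 0
    [[21, 22], [23, 24], [25, 26]],  # device 1
    [[31, 32], [33, 34], [35, 36]],  # device 2
]
out = all_to_all(local, split_axis=0, concat_axis=0)
print(out[0])  # [[11, 12], [21, 22], [31, 32]]
```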
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
split_axis | ::mlir::IntegerAttr | index attribute |
concat_axis | ::mlir::IntegerAttr | index attribute |
Operands:
| Operand | Description |
|---|---|
input | non-0-ranked.tensor of any type values |
Results:
| Result | Description |
|---|---|
result | non-0-ranked.tensor of any type values |
shard.broadcast (shard::BroadcastOp)
Broadcast over a device grid.
Syntax:
operation ::= `shard.broadcast` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Copies the input tensor on root to all devices in each group defined by
grid_axes. The root device is defined by its in-group multi-index.
The contents of input tensors on non-root devices are ignored.
Example:
shard.grid @grid0(shape = 2x2)
%1 = shard.broadcast %0 on @grid0
grid_axes = [0]
root = [0]
: (tensor<2xi8>) -> tensor<2xi8>
Input:
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
device (1, 0) -> | * * | * * | <- device (1, 1)
+-------+-------+
Output:
+-------+-------+
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1)
+-------+-------+
device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1)
+-------+-------+
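A Python sketch of broadcast semantics, assuming local tensors keyed by full device coordinates; `broadcast` is a hypothetical helper. Each device locates its group's root by keeping its own coordinates off grid_axes and substituting the in-group multi-index root on grid_axes.

```python
def broadcast(local, grid_axes, root):
    """Simulate shard.broadcast: each device receives its group's root tensor."""
    result = {}
    for device in local:
        src = list(device)  # coordinates off grid_axes stay the same
        for axis, coord in zip(grid_axes, root):
            src[axis] = coord  # root's in-group multi-index on grid_axes
        result[device] = list(local[tuple(src)])
    return result


# Contents on the non-root devices (1, 0) and (1, 1) are ignored.
local = {(0, 0): [1, 2], (0, 1): [3, 4], (1, 0): [0, 0], (1, 1): [0, 0]}
out = broadcast(local, grid_axes=[0], root=[0])
print(out[(1, 0)], out[(1, 1)])  # [1, 2] [3, 4]
```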
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
| Operand | Description |
|---|---|
input | ranked tensor of any type values |
root_dynamic | variadic of index |
Results:
| Result | Description |
|---|---|
result | ranked tensor of any type values |
shard.gather (shard::GatherOp)
Gather over a device grid.
Syntax:
operation ::= `shard.gather` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`gather_axis` `=` $gather_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Concatenates all tensor slices from a device group defined by grid_axes along
the tensor dimension gather_axis and returns the resulting tensor on each
root device. The result on all other (non-root) devices is undefined.
The root device is defined by its in-group multi-index.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.gather %0 on @grid0 grid_axes = [1]
gather_axis = 1 root = [1]
: (tensor<2x2xi8>) -> tensor<2x4xi8>
Input:
gather tensor
axis 1
------------>
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
Result:
+-------------+
| 1 2 5 6 | <- device (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- device (1, 1)
| 11 12 15 16 |
+-------------+
Devices (0, 0) and (1, 0) have an undefined result.
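A Python sketch of gather: it concatenates the group's slices like an all-gather, but only the root device keeps the result; non-root devices are modeled as holding `None`. This is a hypothetical helper over 2-D nested-list tensors, not an MLIR API.

```python
def concat(tensors, axis):
    """Concatenate 2-D nested lists along axis 0 or 1."""
    if axis == 0:
        return [row for t in tensors for row in t]
    return [[x for t in tensors for x in t[r]] for r in range(len(tensors[0]))]


def gather(local, grid_axes, gather_axis, root):
    """Simulate shard.gather; non-root devices get None (undefined)."""
    result = {}
    for device in local:
        if tuple(device[a] for a in grid_axes) != tuple(root):
            result[device] = None
            continue
        # Root: collect the group's slices in in-group order.
        members = sorted(
            (d for d in local
             if all(d[a] == device[a] for a in range(len(d)) if a not in grid_axes)),
            key=lambda d: tuple(d[a] for a in grid_axes))
        result[device] = concat([local[d] for d in members], gather_axis)
    return result


local = {
    (0, 0): [[1, 2], [3, 4]], (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]], (1, 1): [[13, 14], [15, 16]],
}
out = gather(local, grid_axes=[1], gather_axis=1, root=[1])
print(out[(0, 1)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
print(out[(0, 0)])  # None
```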
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
gather_axis | ::mlir::IntegerAttr | index attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
| Operand | Description |
|---|---|
input | non-0-ranked.tensor of any type values |
root_dynamic | variadic of index |
Results:
| Result | Description |
|---|---|
result | non-0-ranked.tensor of any type values |
shard.get_sharding (shard::GetShardingOp)
Get the sharding of the given tensor.
Syntax:
operation ::= `shard.get_sharding` $source attr-dict `:` type($source) `->` type($result)
This operation returns the sharding of the given tensor as a Sharding.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
| Operand | Description |
|---|---|
source | ranked tensor of any type values |
Results:
| Result | Description |
|---|---|
result | sharding definition |
shard.grid (shard::GridOp)
Description of a device/process grid.
Syntax:
operation ::= `shard.grid` $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
attr-dict
The shard.grid operation is a symbol operation that identifies a specific grid. The operation has two attributes:
- sym_name: This attribute uniquely identifies the name of the grid. This name serves as a symbolic reference to the grid throughout the MLIR module, allowing for consistent referencing and easier debugging.
- shape: This attribute represents the shape of the device grid. It uses the same notation as a tensor shape, including dynamic dimensions, for example 2x?x4. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be known at compile time.
Example:
// A device grid with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
shard.grid @grid0(shape = 4x8x12)
// A device grid with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
shard.grid @grid1(shape = 4x?)
// A device grid with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
shard.grid @grid2(shape = ?x4)
// A device grid with 2 axes, the number of devices along both axes
// is unknown
shard.grid @grid3(shape = ?x?)
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), Symbol
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
sym_name | ::mlir::StringAttr | string attribute |
shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
shard.grid_shape (shard::GridShapeOp)
Get the shape of the grid.
Syntax:
operation ::= `shard.grid_shape` $grid (`axes` `=` $axes^)?
attr-dict `:` type($result)
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
Results:
| Result | Description |
|---|---|
result | variadic of index |
shard.neighbors_linear_indices (shard::NeighborsLinearIndicesOp)
For a given grid index, get the linear indices of the direct neighbor processes along the given split.
Syntax:
operation ::= `shard.neighbors_linear_indices` `on` $grid `[` $device `]`
`split_axes` `=` $split_axes
attr-dict `:` type(results)
Example:
shard.grid @grid0(shape = 10x20x30)
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%down, %up = shard.neighbors_linear_indices on @grid0[%c1, %c2, %c3] split_axes = [1] : index, index
The above returns two indices, 633 and 693, which are the linear indices
of the previous process (1, 1, 3) and the next process (1, 3, 3) along
split axis 1.
A negative value is returned if there is no neighbor in the respective
direction along the given split_axes.
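The indices in the example follow from a row-major linearization of the grid (last axis varying fastest). The sketch below reproduces them for a single split axis only; returning -1 for a missing neighbor is this sketch's choice of negative value, and both function names are hypothetical.

```python
def linear_index(device, grid_shape):
    """Row-major linear index of a multi-index (last axis varies fastest)."""
    idx = 0
    for coord, size in zip(device, grid_shape):
        idx = idx * size + coord
    return idx


def neighbors_linear_indices(device, grid_shape, split_axis):
    """(down, up) neighbor linear indices along one split axis, -1 if absent."""
    result = []
    for step in (-1, 1):
        coord = device[split_axis] + step
        if 0 <= coord < grid_shape[split_axis]:
            neighbor = list(device)
            neighbor[split_axis] = coord
            result.append(linear_index(neighbor, grid_shape))
        else:
            result.append(-1)  # no neighbor in this direction
    return tuple(result)


print(neighbors_linear_indices((1, 2, 3), (10, 20, 30), split_axis=1))  # (633, 693)
print(neighbors_linear_indices((1, 0, 3), (10, 20, 30), split_axis=1))  # (-1, 633)
```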
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
split_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
Operands:
| Operand | Description |
|---|---|
device | variadic of index |
Results:
| Result | Description |
|---|---|
neighbor_down | index |
neighbor_up | index |
shard.process_linear_index (shard::ProcessLinearIndexOp)
Get the linear index of the current device.
Syntax:
operation ::= `shard.process_linear_index` `on` $grid attr-dict `:` type($result)
Example:
%idx = shard.process_linear_index on @grid : index
If @grid has shape (10, 20, 30), a device with multi-index
(1, 2, 3) has linear index 3 + 30*2 + 20*30*1 = 663.
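Written out, this is a row-major linearization with the last grid axis varying fastest. A Python sketch (the function name is hypothetical):

```python
def process_linear_index(multi_index, grid_shape):
    """Row-major linearization; for a 3-D grid: (i0 * s1 + i1) * s2 + i2."""
    idx = 0
    for coord, size in zip(multi_index, grid_shape):
        idx = idx * size + coord
    return idx


# 3 + 30*2 + 20*30*1
print(process_linear_index((1, 2, 3), (10, 20, 30)))  # 663
```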
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
Results:
| Result | Description |
|---|---|
result | index |
shard.process_multi_index (shard::ProcessMultiIndexOp)
Get the multi-index of the current device along the specified grid axes.
Syntax:
operation ::= `shard.process_multi_index` `on` $grid (`axes` `=` $axes^)?
attr-dict `:` type($result)
It is used in the SPMD format of IR.
The axes must be non-negative and less than the total number of grid axes.
If the axes are empty, the index along all axes is returned.
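A sketch of how this multi-index relates to the row-major linear index used by shard.process_linear_index: delinearize, then pick the requested axes. The function is hypothetical, and returning coordinates in the order the axes are listed is an assumption of this sketch.

```python
def process_multi_index(linear_index, grid_shape, axes=()):
    """Delinearize a row-major index, then select the requested grid axes."""
    coords = []
    for size in reversed(grid_shape):
        linear_index, coord = divmod(linear_index, size)
        coords.append(coord)
    coords.reverse()
    for axis in axes:
        # Mirrors the documented constraint on axes.
        assert 0 <= axis < len(grid_shape), "axis out of range"
    # Empty axes means "all axes".
    return tuple(coords[a] for a in (axes or range(len(grid_shape))))


print(process_multi_index(663, (10, 20, 30)))               # (1, 2, 3)
print(process_multi_index(663, (10, 20, 30), axes=[2, 0]))  # (3, 1)
```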
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
Results:
| Result | Description |
|---|---|
result | variadic of index |
shard.recv (shard::RecvOp)
Receive over a device grid.
Syntax:
operation ::= `shard.recv` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
attr-dict `:` functional-type(operands, results)
Receive a tensor from the device source, which is identified by its in-group
multi-index. The groups are defined by grid_axes.
The content of the input tensor is ignored.
Interfaces: OpAsmOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
source | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
| Operand | Description |
|---|---|
input | non-0-ranked.tensor of any type values |
source_dynamic | variadic of index |
Results:
| Result | Description |
|---|---|
result | ranked tensor of any type values |
shard.reduce (shard::ReduceOp)
Reduce over a device grid.
Syntax:
operation ::= `shard.reduce` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`reduction` `=` $reduction^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Reduces the input tensor across all devices within the groups defined by
grid_axes, using the specified reduction method. The operation performs an
element-wise reduction over the tensor slices from all devices in each group.
The reduction result will be returned on the root device of each group.
It is undefined on all other (non-root) devices.
The root device is defined by its in-group multi-index.
The accumulation element type is determined by the result type and does not
need to match the input element type. Before performing the reduction, each
input element is converted to the result element type.
Attributes:
reduction: Indicates the reduction method.
Example:
%1 = shard.reduce %0 on @grid0 grid_axes = [1, 0]
reduction = <max> root = [2, 3]
: (tensor<3x4xf32>) -> tensor<3x4xf64>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::shard::ReductionKindAttr | Reduction of an iterator/grid dimension. |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
| Operand | Description |
|---|---|
input | ranked tensor of any type values |
root_dynamic | variadic of index |
Results:
| Result | Description |
|---|---|
result | ranked tensor of any type values |
shard.reduce_scatter (shard::ReduceScatterOp)
Reduce-scatter over a device grid.
Syntax:
operation ::= `shard.reduce_scatter` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`reduction` `=` $reduction^)?
`scatter_axis` `=` $scatter_axis
attr-dict `:` type($input) `->` type($result)
Reduces the input tensor across all devices within the groups defined by
grid_axes using the specified reduction method. The reduction is performed
element-wise across the tensor pieces from all devices in the group.
After reduction, the reduction result is scattered (split and distributed)
across the device group along scatter_axis.
Example:
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
reduction = <max> scatter_axis = 0
: tensor<2x2xf32> -> tensor<1x2xf64>
Input:
device
(0, 1)
↓
+-------+-------+ | scatter tensor
device (0, 0) -> | 1 2 | 5 6 | | axis 0
| 3 4 | 7 8 | ↓
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 |
| 11 12 | 15 16 |
+-------+-------+
↑
device
(1, 1)
Result:
+-------+
| 5 6 | <- device (0, 0)
+-------+
| 7 8 | <- device (0, 1)
+-------+
| 13 14 | <- device (1, 0)
+-------+
| 15 16 | <- device (1, 1)
+-------+
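A Python sketch reproducing the example: reduce element-wise within each group, then give each device the chunk matching its in-group index along scatter_axis. It assumes 2-D nested-list tensors and an even split, models the f32 to f64 result type with `float`, and passes the reduction as a Python binary function; all names are hypothetical.

```python
from functools import reduce


def reduce_scatter(local, grid_axes, scatter_axis, reduction=max):
    """Simulate shard.reduce_scatter on 2-D local tensors."""
    result = {}
    for device in local:
        members = sorted(
            (d for d in local
             if all(d[a] == device[a] for a in range(len(d)) if a not in grid_axes)),
            key=lambda d: tuple(d[a] for a in grid_axes))
        # Element-wise reduction over the group: zip rows, then zip elements.
        reduced = [[reduce(reduction, map(float, vals)) for vals in zip(*rows)]
                   for rows in zip(*(local[d] for d in members))]
        # Scatter: this device keeps the chunk at its in-group position.
        rank, n = members.index(device), len(members)
        extent = len(reduced) if scatter_axis == 0 else len(reduced[0])
        size = extent // n  # assumes an even split
        lo, hi = rank * size, (rank + 1) * size
        result[device] = (reduced[lo:hi] if scatter_axis == 0
                          else [row[lo:hi] for row in reduced])
    return result


local = {
    (0, 0): [[1, 2], [3, 4]], (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]], (1, 1): [[13, 14], [15, 16]],
}
out = reduce_scatter(local, grid_axes=[1], scatter_axis=0)
print(out[(0, 1)])  # [[7.0, 8.0]]
```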
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::shard::ReductionKindAttr | Reduction of an iterator/grid dimension. |
scatter_axis | ::mlir::IntegerAttr | index attribute |
Operands:
| Operand | Description |
|---|---|
input | non-0-ranked.tensor of any type values |
Results:
| Result | Description |
|---|---|
result | ranked tensor of any type values |
shard.scatter (shard::ScatterOp)
Scatter over a device grid.
Syntax:
operation ::= `shard.scatter` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`scatter_axis` `=` $scatter_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
For each device group defined by grid_axes, the input tensor on the root
device is split along axis scatter_axis and distributed across the group.
The content of the input on all other (non-root) devices is ignored.
The root device is defined by its in-group multi-index.
Example:
shard.grid @grid0(shape = 2x2)
%1 = shard.scatter %0 on @grid0 grid_axes = [0]
scatter_axis = 0
root = [1]
: (tensor<2x2xi8>) -> tensor<1x2xi8>
Input:
device
(0, 1)
↓
+-------+-------+ | scatter tensor
device (0, 0) -> | * * | * * | | axis 0
| * * | * * | ↓
+-------+-------+
device (1, 0) -> | 1 2 | 5 6 |
| 3 4 | 7 8 |
+-------+-------+
↑
device
(1, 1)
Result:
device
(0, 1)
↓
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 |
+-------+-------+
device (1, 0) -> | 3 4 | 7 8 |
+-------+-------+
↑
device
(1, 1)
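A Python sketch of scatter under the same simplifications as the example above (2-D nested-list tensors, even split); the helper is hypothetical. Each device finds its group's root, then keeps the chunk of the root's tensor that matches its own in-group index.

```python
def scatter(local, grid_shape, grid_axes, scatter_axis, root):
    """Simulate shard.scatter: split the root's tensor across its group."""
    result = {}
    for device in local:
        # Root of this device's group: same coords off grid_axes, `root` on them.
        src = list(device)
        for axis, coord in zip(grid_axes, root):
            src[axis] = coord
        tensor = local[tuple(src)]
        # Group size and this device's in-group index (row-major over grid_axes).
        group_size, idx = 1, 0
        for axis in grid_axes:
            group_size *= grid_shape[axis]
            idx = idx * grid_shape[axis] + device[axis]
        extent = len(tensor) if scatter_axis == 0 else len(tensor[0])
        size = extent // group_size  # assumes an even split
        lo, hi = idx * size, (idx + 1) * size
        result[device] = (tensor[lo:hi] if scatter_axis == 0
                          else [row[lo:hi] for row in tensor])
    return result


ignored = [[0, 0], [0, 0]]  # non-root inputs are never read
local = {(0, 0): ignored, (0, 1): ignored,
         (1, 0): [[1, 2], [3, 4]], (1, 1): [[5, 6], [7, 8]]}
out = scatter(local, grid_shape=(2, 2), grid_axes=[0], scatter_axis=0, root=[1])
print(out[(0, 0)], out[(1, 0)])  # [[1, 2]] [[3, 4]]
```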
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
scatter_axis | ::mlir::IntegerAttr | index attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
| Operand | Description |
|---|---|
input | non-0-ranked.tensor of any type values |
root_dynamic | variadic of index |
Results:
| Result | Description |
|---|---|
result | ranked tensor of any type values |
shard.send (shard::SendOp)
Send over a device grid.
Syntax:
operation ::= `shard.send` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
attr-dict `:` functional-type(operands, results)
Send the input tensor to the device destination, which is identified by its in-group
multi-index. The groups are defined by grid_axes.
Interfaces: OpAsmOpInterface, SymbolUserOpInterface
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
destination | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
| Operand | Description |
|---|---|
input | non-0-ranked.tensor of any type values |
destination_dynamic | variadic of index |
Results:
| Result | Description |
|---|---|
result | ranked tensor of any type values |
shard.shard (shard::ShardOp)
Annotate how a tensor is sharded across a grid.
Syntax:
operation ::= `shard.shard` $src `to` $sharding
(`annotate_for_users` $annotate_for_users^)?
attr-dict `:` type($result)
The shard.shard operation is designed to specify and guide the sharding behavior of a tensor value across a grid topology. This operation has two operands and one optional attribute:
- src: This operand represents the tensor value that needs to be annotated for sharding.
- sharding: This operand is of type ShardingType, the core data structure for representing the distribution of a tensor on a grid. It is typically defined by a shard.sharding operation.
- annotate_for_users: A unit attribute addressing the scenario when a tensor’s sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.
Example:
func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
...
}
func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
...
}
func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
%1 = shard.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
...
}
// The first shard.shard op applies to %arg0, the second shard.shard op
// applies to the operand of op0, and the third shard.shard op applies to the
// operand of op1
func.func @both_result_and_multi_operands_annotated(
%arg0 : tensor<4x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
%sharding1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%1 = shard.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%sharding2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
%2 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
"op0"(%1) : ...
"op1"(%2) : ...
...
}
The following usages are undefined:
func.func @annotate_on_same_result_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
%1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_result_same_value_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
%1 = shard.shard %arg0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_operand_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%1 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
...
}
func.func @result_annotated_after_operand(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
annotate_for_users | ::mlir::UnitAttr | unit attribute |
Operands:
| Operand | Description |
|---|---|
src | ranked tensor of any type values |
sharding | sharding definition |
Results:
| Result | Description |
|---|---|
result | ranked tensor of any type values |
shard.shard_shape (shard::ShardShapeOp)
Get the shard shape for a given process/device.
Syntax:
operation ::= `shard.shard_shape` `dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
`sharding` `=` $sharding
`device` `=` custom<DynamicIndexList>($device_dynamic, $device)
attr-dict `:` type(results)
The device/process id is the multi-index of the device/process in the grid.
This operation might be used during partitioning when the shard shape depends
on (non-constant) values used in shard.sharding.
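The shape computation behind shard.shard_shape can be sketched in Python. This is a simplified model, not the dialect's implementation. It assumes static sizes and requires evenly divisible splits unless explicit per-shard sizes are passed (for example, sizes decoded from sharded_dims_offsets). It also assumes that when a dimension is split along several grid axes, the first-listed axis is the outermost. The helpers shard_index and shard_shape are hypothetical names.

```python
from math import prod

def shard_index(device, axes, grid_shape):
    """Linear index of `device`'s shard along one tensor dimension.

    Assumption: for a split_axes sub-array such as [x, y], the first listed
    grid axis is the outermost one.
    """
    index = 0
    for axis in axes:
        index = index * grid_shape[axis] + device[axis]
    return index

def shard_shape(tensor_shape, grid_shape, split_axes, device, shard_sizes=None):
    """Local (halo-free) shape of the shard held by `device`.

    device is the device's multi-index in the grid. shard_sizes optionally
    maps a tensor dimension to the explicit size of each of its shards.
    """
    local = list(tensor_shape)
    for dim, axes in enumerate(split_axes):
        if not axes:
            continue  # empty sub-array: this dimension is replicated
        if shard_sizes is not None and dim in shard_sizes:
            # Uneven shards: the device's position selects its shard size.
            local[dim] = shard_sizes[dim][shard_index(device, axes, grid_shape)]
            continue
        num_shards = prod(grid_shape[axis] for axis in axes)
        if tensor_shape[dim] % num_shards != 0:
            raise ValueError(f"dim {dim} (size {tensor_shape[dim]}) does not "
                             f"split evenly into {num_shards} shards")
        local[dim] = tensor_shape[dim] // num_shards
    return local
```

For example, shard_shape([4, 8], [2, 2, 4], [[0]], (1, 0, 3)) yields [2, 8]. With even splits, the device index does not change the result. It only matters once shards have different sizes.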
Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
| Attribute | MLIR Type | Description |
|---|---|---|
dims | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
device | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
| Operand | Description |
|---|---|
dims_dynamic | variadic of index |
sharding | sharding definition |
device_dynamic | variadic of index |
Results: ¶
| Result | Description |
|---|---|
result | variadic of index |
shard.sharding (shard::ShardingOp) ¶
Define a sharding of a tensor.
Syntax:
operation ::= `shard.sharding` $grid
`split_axes` `=` $split_axes
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
(`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
attr-dict `:` type($result)
The sharding specifies how a tensor is sharded and distributed across the
device grid. It is typically used in a shard.shard operation.
The operation has the following attributes and operands:
grid: this attribute is a FlatSymbolRefAttr that refers to the device grid where the distributed tensor is placed. The symbol must resolve to a shard.grid operation.
split_axes: an array composed of int64_t sub-arrays. The outer array's maximum size is the rank of the related tensor. If the i-th sub-array is [x, y], the tensor's i-th dimension is split along the x and y axes of the device grid.
[Optional] halo_sizes: sizes of the halos to be added for each sharded tensor dimension. halo_sizes is provided as a flattened 1d array of i64s, with 2 values for each sharded dimension. halo_sizes = [1, 2] means that the first sharded dimension gets an additional halo of size 1 at its start and a halo of size 2 at its end. halo_sizes = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions: the first sharded dimension gets [1, 2] halos and the second gets [2, 3] halos. ? indicates dynamic halo sizes.
[Optional] sharded_dims_offsets: offsets for each shard and sharded tensor dimension. sharded_dims_offsets is provided as a flattened 1d array of i64s. For each sharded tensor dimension, it lists the offsets (starting indices) of all shards in that dimension, followed by one additional value for the end of the last shard. For a 1d sharding, this means that position i holds the exclusive prefix sum for shard i. Since only contiguous sharding is supported, the inclusive prefix sum for shard i is at position i+1.
Assume a 3d-tensor of shape 32x32x32 whose first 2 dimensions are each sharded into 2 shards.
Then sharded_dims_offsets = [0, 24, 32, 0, 20, 32] means that the device holding the first
shard of both sharded dimensions gets a shard of shape 24x20x32, and the device holding the
second shard of both gets a shard of shape 8x12x32. ? indicates dynamic shard dimensions.
halo_sizes and sharded_dims_offsets are mutually exclusive.
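The layout of the flattened sharded_dims_offsets array can be checked with a small decoder. This is a minimal sketch with a hypothetical function name, and it only covers static values (no ? entries). Each sharded dimension contributes one offset per shard plus a closing end offset. The checks that offsets start at 0, never decrease, and end at the dimension size are assumptions drawn from the "contiguous sharding" wording, not a documented verifier.

```python
def decode_sharded_dims_offsets(flat, shards_per_dim, dim_sizes=None):
    """Split a flattened sharded_dims_offsets array into per-dimension shard sizes.

    shards_per_dim[k] is the number of shards of the k-th *sharded* tensor
    dimension, which contributes shards_per_dim[k] + 1 offsets. dim_sizes, if
    given, holds the full size of each sharded dimension and is checked
    against that dimension's final offset.
    """
    sizes, pos = [], 0
    for k, num_shards in enumerate(shards_per_dim):
        offsets = flat[pos:pos + num_shards + 1]
        if len(offsets) != num_shards + 1:
            raise ValueError("sharded_dims_offsets is too short")
        if offsets[0] != 0 or any(b < a for a, b in zip(offsets, offsets[1:])):
            raise ValueError(f"offsets of sharded dim {k} must start at 0 "
                             "and be non-decreasing")
        if dim_sizes is not None and offsets[-1] != dim_sizes[k]:
            raise ValueError(f"last offset of sharded dim {k} must equal its size")
        # Each shard's size is the difference between consecutive offsets.
        sizes.append([end - start for start, end in zip(offsets, offsets[1:])])
        pos += num_shards + 1
    if pos != len(flat):
        raise ValueError("sharded_dims_offsets has unused trailing values")
    return sizes
```

For the 32x32x32 example, decoding [0, 24, 32, 0, 20, 32] with two shards per sharded dimension gives [[24, 8], [20, 12]]. Pairing the first entries gives 24x20x32, and pairing the second entries gives 8x12x32.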
Examples:
shard.grid @grid0(shape = 2x2x4)
shard.grid @grid1d_4(shape = 4)
// The tensor is fully replicated on @grid0.
// Currently, there must be at least one sub-array present in split_axes, even
// if it's empty. Otherwise, a parsing error will occur.
%sharding0 = shard.sharding @grid0 split_axes = [[]] : !shard.sharding
// The tensor is sharded on the first dimension along axis 0 of @grid0
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
// Could be used for a shard.shard op
%sharded0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
// The tensor is sharded on its first dimension along axis 0 of @grid0
// and it has halo-sizes of 1 and 2 on the sharded dim.
%halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2] : !shard.sharding
%sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>
// The tensor is sharded on its second dimension along axis 0 of @grid1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
%sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14] : !shard.sharding
%sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
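The %halo_sharding example can be worked through numerically. Axis 0 of @grid0 has size 2, so each device's core shard of the 4x8 tensor is 2x8. The sketch below assumes that a shard's local size in each sharded dimension is its core size plus the start and end halos, and that halo pairs are consumed in the order the sharded dimensions appear. halo_padded_shape is a hypothetical helper, and requiring exactly two values per sharded dimension is a simplifying assumption.

```python
def halo_padded_shape(shard_shape, split_axes, halo_sizes):
    """Grow a shard's (core) shape by its halos.

    halo_sizes is flattened: one (start, end) pair per sharded dimension,
    taken in the order the sharded (non-empty split_axes) dimensions appear.
    Replicated dimensions receive no halo.
    """
    sharded_dims = [dim for dim, axes in enumerate(split_axes) if axes]
    if len(halo_sizes) != 2 * len(sharded_dims):
        raise ValueError("expected two halo sizes per sharded dimension")
    padded = list(shard_shape)
    for k, dim in enumerate(sharded_dims):
        start, end = halo_sizes[2 * k], halo_sizes[2 * k + 1]
        padded[dim] += start + end
    return padded
```

Under these assumptions, halo_padded_shape([2, 8], [[0]], [1, 2]) gives [5, 8] for each device holding %sharded1.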
Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
split_axes | ::mlir::shard::GridAxesArrayAttr | |
static_sharded_dims_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
static_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
| Operand | Description |
|---|---|
dynamic_sharded_dims_offsets | variadic of 64-bit signless integer |
dynamic_halo_sizes | variadic of 64-bit signless integer |
Results: ¶
| Result | Description |
|---|---|
result | sharding definition |
shard.shift (shard::ShiftOp) ¶
Shift over a device grid.
Syntax:
operation ::= `shard.shift` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`shift_axis` `=` $shift_axis
`offset` `=` $offset
(`rotate` $rotate^)?
attr-dict `:` type($input) `->` type($result)
Within each device group defined by grid_axes, shifts input tensors along the
device grid’s axis shift_axis by the specified offset. The shift_axis must
be one of the grid_axes. If the rotate attribute is set, the shift is circular.
That is, the offset wraps around according to the group size along shift_axis.
Otherwise, the results on devices without a corresponding source are undefined.
Example:
shard.grid @grid0(shape = 2x4)
%1 = shard.shift %0 on @grid0 grid_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
Input:
grid axis 1
----------->
+----+----+----+----+
| 1 | 2 | 3 | 4 |
+----+----+----+----+
| 5 | 6 | 7 | 8 |
+----+----+----+----+
Result:
+----+----+----+----+
| 3 | 4 | 1 | 2 |
+----+----+----+----+
| 7 | 8 | 5 | 6 |
+----+----+----+----+
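The diagram can be reproduced by simulating the shift inside each device group. With grid_axes = [1] on a 2x4 grid, each row is its own group. In this sketch (hypothetical helper names), values[i] is the input on the device at position i along shift_axis. The sketch assumes that a positive offset moves data toward higher positions, so position i receives from position i - offset. The documented example cannot confirm the direction, because an offset of 2 in a group of size 4 gives the same result either way. Undefined results are represented as None.

```python
def shift_group(values, offset, rotate):
    """Simulate shard.shift within one device group along shift_axis."""
    n = len(values)
    result = []
    for i in range(n):
        src = i - offset  # assumed direction: data flows toward higher positions
        if rotate:
            result.append(values[src % n])  # circular shift wraps around the group
        elif 0 <= src < n:
            result.append(values[src])
        else:
            result.append(None)  # no corresponding source: undefined
    return result

def shift_grid(rows, offset, rotate):
    """2-D grid with grid_axes = [1] and shift_axis = 1: each row is a group."""
    return [shift_group(row, offset, rotate) for row in rows]
```

shift_grid([[1, 2, 3, 4], [5, 6, 7, 8]], 2, True) returns [[3, 4, 1, 2], [7, 8, 5, 6]], matching the result table.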
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultShape
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
shift_axis | ::mlir::IntegerAttr | index attribute |
offset | ::mlir::IntegerAttr | 64-bit signless integer attribute |
rotate | ::mlir::UnitAttr | unit attribute |
Operands: ¶
| Operand | Description |
|---|---|
input | non-0-ranked.tensor of any type values |
Results: ¶
| Result | Description |
|---|---|
result | ranked tensor of any type values |
shard.update_halo (shard::UpdateHaloOp) ¶
Update halo data.
Syntax:
operation ::= `shard.update_halo` $destination
`on` $grid
`split_axes` `=` $split_axes
(`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
attr-dict `:` type($result)
This operation updates the halo regions of shards, for example when their sharding specifies halos and the actual tensor/memref data might have changed on the remote devices. Such changes might be caused by mutating operations, or by the new halo regions being larger than the existing ones.
The destination is expected to be initialized with the local data (not the halos).
Assumes all devices hold tensors with same-sized halo data, as specified
by halo_sizes/static_halo_sizes.
split_axes specifies for each tensor axis along which grid axes its halo
data is updated.
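The effect of a halo update along a single grid axis can be modeled as a neighbour exchange. This Python sketch is an illustration under stated assumptions, not the lowering the dialect performs. It assumes 1-D shards laid out as [start halo | core | end halo] in device order along one grid axis, a non-periodic grid, and outer-boundary halos that keep their current values. update_halo_1d is a hypothetical name.

```python
def update_halo_1d(shards, halo_start, halo_end):
    """Refresh the halos of 1-D shards ordered along one grid axis.

    Each shard is [start halo | core | end halo]; only the cores are trusted
    on entry. The start halo of shard i is copied from the last `halo_start`
    core elements of shard i - 1, and the end halo from the first `halo_end`
    core elements of shard i + 1. Halos with no neighbour keep their values.
    """
    cores = [s[halo_start:len(s) - halo_end] for s in shards]
    updated = []
    for i, shard in enumerate(shards):
        new = list(shard)  # destination-style: start from the local data
        if i > 0 and halo_start:
            new[:halo_start] = cores[i - 1][-halo_start:]
        if i + 1 < len(shards) and halo_end:
            new[len(new) - halo_end:] = cores[i + 1][:halo_end]
        updated.append(new)
    return updated
```

Consider two shards of the global data 0..7 with halo sizes [1, 2], using -1 as a stale placeholder. After the update, the first shard reads [-1, 0, 1, 2, 3, 4, 5] and the second reads [3, 4, 5, 6, 7, -1, -1].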
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, DestinationStyleOpInterface, NoMemoryEffect (MemoryEffectOpInterface), SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
| Attribute | MLIR Type | Description |
|---|---|---|
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
split_axes | ::mlir::shard::GridAxesArrayAttr | |
static_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
| Operand | Description |
|---|---|
destination | non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values |
halo_sizes | variadic of 64-bit signless integer |
Results: ¶
| Result | Description |
|---|---|
result | non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values |
Attributes ¶
GridAxesArrayAttr ¶
Syntax:
#shard.axisarray<
::llvm::ArrayRef<GridAxesAttr> # axes
>
Parameters: ¶
| Parameter | C++ type | Description |
|---|---|---|
| axes | ::llvm::ArrayRef<GridAxesAttr> | |
ReductionKindAttr ¶
Reduction of an iterator/grid dimension.
Syntax:
#shard.partial<
::mlir::shard::ReductionKind # value
>
Parameters: ¶
| Parameter | C++ type | Description |
|---|---|---|
| value | ::mlir::shard::ReductionKind | an enum of type ReductionKind |