'mesh' Dialect
The mesh dialect contains a set of attributes, operations and interfaces that are useful for representing sharding and communication on a device mesh cluster.
Collective Communication Operations
There are a number of operations in the Mesh dialect to facilitate communication between devices in a mesh. It is assumed that the user is familiar with collective operations. Wikipedia has a good explanation. The main addition is that the collectives in this dialect have mesh semantics.
Device groups
The operation attributes mesh and mesh_axes specify a list of device mesh axes that partition the devices into disjoint groups. The collective operation is performed between devices in the same group. Devices that have the same coordinates outside of the axes in mesh_axes are in the same group. A group is described by its multi-index along the axes outside of mesh_axes.
For example, if we have a device mesh of size 2x3x4x5 and the partition mesh axes list is [0, 1], then devices are partitioned into the groups { { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }. The device groups are indexed by { (k, m) | 0<=k<4, 0<=m<5 }. Devices (1, 0, 2, 3) and (1, 1, 2, 3) are in the same group. Device (1, 0, 2, 4) is in another group.
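The grouping rule can be sketched in Python. This only illustrates the semantics and is not part of the dialect; the helper name group_index is hypothetical. It identifies a device's group by keeping its coordinates along the axes outside of mesh_axes:

```python
from typing import Sequence, Tuple


def group_index(device: Sequence[int], mesh_axes: Sequence[int]) -> Tuple[int, ...]:
    """Return the multi-index of the group that `device` belongs to.

    Devices that agree on every coordinate outside of `mesh_axes` share a
    group, so the group is identified by exactly those coordinates.
    """
    partition_axes = set(mesh_axes)
    return tuple(
        coord for axis, coord in enumerate(device) if axis not in partition_axes
    )


# Device mesh of size 2x3x4x5 partitioned along mesh axes [0, 1].
print(group_index((1, 0, 2, 3), [0, 1]))  # (2, 3)
print(group_index((1, 1, 2, 3), [0, 1]))  # (2, 3): same group as above
print(group_index((1, 0, 2, 4), [0, 1]))  # (2, 4): a different group
```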
Some collective operations, like all-to-all and all-gather, care about the order of devices. The order of devices in a device group is induced by the order of axes in mesh_axes. The axes are ordered from outer to inner. If we have an axis list [3, 1], then device (i, 1, k, 0) will precede both devices (i, 0, k, 1) and (i, 2, k, 0).
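A minimal Python sketch of this ordering, assuming a device's position within its group is the row-major linearization of its coordinates along mesh_axes, with the first listed axis outermost. The function name in_group_position is hypothetical:

```python
from typing import Sequence


def in_group_position(
    device: Sequence[int], mesh_shape: Sequence[int], mesh_axes: Sequence[int]
) -> int:
    """Linearize the coordinates of `device` along `mesh_axes`.

    The first axis in `mesh_axes` is the outermost (slowest varying) one.
    """
    position = 0
    for axis in mesh_axes:
        position = position * mesh_shape[axis] + device[axis]
    return position


mesh_shape = (2, 3, 4, 5)
axes = [3, 1]
# Fix i = 0 and k = 0; only the coordinates along axes 3 and 1 matter.
first = in_group_position((0, 1, 0, 0), mesh_shape, axes)   # (i, 1, k, 0)
second = in_group_position((0, 0, 0, 1), mesh_shape, axes)  # (i, 0, k, 1)
third = in_group_position((0, 2, 0, 0), mesh_shape, axes)   # (i, 2, k, 0)
print(first, second, third)  # 1 3 2
```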
In-group Device
Some operations, like broadcast, scatter and send, specify devices in each device group. These devices are represented with their multi-index over the mesh axes that are not constant within a device group. These are the axes specified by the mesh_axes attribute.
For example, on a 3D mesh an operation with mesh_axes = [0, 2] would specify an in-group device with (i, j). Then for each group with index g on the second axis, the in-group device would be (i, g, j).
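The mapping from an in-group device to a full mesh multi-index can be sketched as follows. This illustrative Python helper (full_device_index is a hypothetical name) assumes the in-group index lists coordinates in the order of mesh_axes and the group index lists the remaining axes in increasing order:

```python
from typing import Sequence, Tuple


def full_device_index(
    in_group: Sequence[int],
    group: Sequence[int],
    mesh_axes: Sequence[int],
    mesh_rank: int,
) -> Tuple[int, ...]:
    """Merge an in-group index (over mesh_axes) and a group index (over the
    remaining axes) into one multi-index over all mesh axes."""
    position_in_group = {axis: pos for pos, axis in enumerate(mesh_axes)}
    remaining = iter(group)
    index = []
    for axis in range(mesh_rank):
        if axis in position_in_group:
            index.append(in_group[position_in_group[axis]])
        else:
            index.append(next(remaining))
    return tuple(index)


# 3D mesh, mesh_axes = [0, 2]: in-group device (i, j) = (1, 3) in group g = 2.
print(full_device_index((1, 3), (2,), [0, 2], mesh_rank=3))  # (1, 2, 3)
```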
Purity
Collectives that involve the whole device group to perform a single operation are pure. The exceptions are send and recv.
Execution is assumed to be SPMD: not only does each process run the same program, but at the point of execution of a collective operation, all processes are in a coherent state. All compiler transformations must be consistent. Collective operations in the IR that may correspond to the same runtime collective operation must be transformed in a consistent manner. For example, if a collective operation is optimized out, then it must not appear in any path of execution on any process.
Having the operations as Pure implies that if an interpreter is to execute the IR containing the mesh collectives, all processes would execute the same line when they reach a pure collective operation. This requirement stems from the need to be compatible with general optimization passes like dead code elimination and common sub-expression elimination.
Operations
mesh.all_gather (mesh::AllGatherOp)
All-gather over a device mesh.
Syntax:
operation ::= `mesh.all_gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
attr-dict `:` type($input) `->` type($result)
Gathers along the gather_axis tensor axis.
Example:
mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
  : tensor<2x2xi8> -> tensor<2x4xi8>
Input:
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
Result:
gather tensor
axis 1
------------>
+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
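The example above can be reproduced with a small Python simulation. It is a sketch of the semantics only, restricted to a 2D mesh with a single entry in mesh_axes and gather_axis = 1; the function all_gather_2d is hypothetical:

```python
from typing import Dict, List, Tuple

Device = Tuple[int, int]
Matrix = List[List[int]]


def all_gather_2d(local: Dict[Device, Matrix], mesh_axis: int) -> Dict[Device, Matrix]:
    """Simulate all_gather with gather_axis = 1 on a 2D mesh, one mesh axis.

    Devices sharing the coordinate of the other mesh axis form a group. Every
    member receives the members' tensors concatenated along tensor axis 1,
    ordered by their coordinate along `mesh_axis`.
    """
    other_axis = 1 - mesh_axis
    result = {}
    for device in local:
        group = sorted(
            (d for d in local if d[other_axis] == device[other_axis]),
            key=lambda d: d[mesh_axis],
        )
        num_rows = len(local[device])
        result[device] = [
            [value for member in group for value in local[member][row]]
            for row in range(num_rows)
        ]
    return result


local = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
gathered = all_gather_2d(local, mesh_axis=1)
print(gathered[(0, 0)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
print(gathered[(1, 1)])  # [[9, 10, 13, 14], [11, 12, 15, 16]]
```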
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
gather_axis | ::mlir::IntegerAttr | index attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results:
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
mesh.all_reduce (mesh::AllReduceOp)
All-reduce over a device mesh.
Syntax:
operation ::= `mesh.all_reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
attr-dict `:` type($input) `->` type($result)
The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.
Attributes:
reduction: Indicates the reduction method.
Example:
%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
: tensor<3x4xf32> -> tensor<3x4xf64>
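To illustrate why the accumulation element type matters, the Python sketch below sums values across one device group twice: once wrapping every partial sum to the signed 8-bit range (as if the result type were i8), and once without wrapping (standing in for a wider result type such as i32, whose range these small values never approach). The helpers are hypothetical and model only the arithmetic, not the op itself:

```python
from typing import Callable, List


def wrap_i8(x: int) -> int:
    """Wrap an integer into the signed 8-bit range (two's complement)."""
    return (x + 128) % 256 - 128


def all_reduce_sum(
    group_values: List[List[int]], to_result_type: Callable[[int], int]
) -> List[int]:
    """Element-wise sum over the devices of one group.

    Each partial sum is passed through `to_result_type`, modelling a
    reduction carried out in the result element type.
    """
    length = len(group_values[0])
    acc = [0] * length
    for values in group_values:
        for k in range(length):
            acc[k] = to_result_type(acc[k] + values[k])
    return acc


group = [[100, -3], [100, 5]]  # i8 inputs on two devices of one group
print(all_reduce_sum(group, wrap_i8))      # [-56, 2]: i8 accumulation overflows
print(all_reduce_sum(group, lambda x: x))  # [200, 2]: wider accumulation type
```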
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultShape
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.all_slice (mesh::AllSliceOp)
All-slice over a device mesh. This is the inverse of all-gather.
Syntax:
operation ::= `mesh.all_slice` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
attr-dict `:` type($input) `->` type($result)
Slices along the slice_axis tensor axis.
This operation can be thought of as the inverse of all-gather.
Technically, it is not required that all processes have the same input tensor.
Each process will slice a piece of its local tensor based on its in-group device index.
The operation does not communicate data between devices.
Example:
mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
: tensor<2x4xi8> -> tensor<2x2xi8>
Input:
+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
Result:
                 slice tensor
                 axis 1
                 ------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
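A Python sketch of what each process computes locally for a 2D tensor; all_slice here is a hypothetical helper that models the op's semantics, not a binding to it. No data moves between devices: each process keeps the piece selected by its in-group index.

```python
from typing import List

Matrix = List[List[int]]


def all_slice(local: Matrix, in_group_index: int, group_size: int, slice_axis: int) -> Matrix:
    """Return the piece of a 2D local tensor owned by this device.

    The tensor is cut into `group_size` equal pieces along `slice_axis`
    and the piece at `in_group_index` is kept. No data is communicated.
    """
    extent = len(local) if slice_axis == 0 else len(local[0])
    assert extent % group_size == 0, "slice axis must divide evenly"
    width = extent // group_size
    start, stop = in_group_index * width, (in_group_index + 1) * width
    if slice_axis == 0:
        return [row[:] for row in local[start:stop]]
    return [row[start:stop] for row in local]


# Devices (0, 0) and (0, 1) hold the same 2x4 tensor; mesh_axes = [1], so the
# in-group index of device (0, c) is c.
full = [[1, 2, 5, 6], [3, 4, 7, 8]]
print(all_slice(full, in_group_index=0, group_size=2, slice_axis=1))  # [[1, 2], [3, 4]]
print(all_slice(full, in_group_index=1, group_size=2, slice_axis=1))  # [[5, 6], [7, 8]]
```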
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
slice_axis | ::mlir::IntegerAttr | index attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results:
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
mesh.all_to_all (mesh::AllToAllOp)
All-to-all over a device mesh.
Syntax:
operation ::= `mesh.all_to_all` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`split_axis` `=` $split_axis
`concat_axis` `=` $concat_axis
attr-dict `:` type($input) `->` type($result)
Performs an all-to-all on tensor pieces split along split_axis. The resulting pieces are concatenated along concat_axis on each device.
Example:
mesh.mesh @mesh0(shape = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
: tensor<3x2xi8> -> tensor<3x2xi8>
Input:
 device  device  device
  (0)     (1)     (2)
+-------+-------+-------+  | split and concat along
| 11 12 | 21 22 | 31 32 |  | tensor axis 0
| 13 14 | 23 24 | 33 34 |  ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+
Result:
 device  device  device
  (0)     (1)     (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
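The exchange in this example can be simulated in Python. This sketch covers only split_axis = concat_axis = 0 within a single device group; all_to_all_rows is a hypothetical name:

```python
from typing import List

Matrix = List[List[int]]


def all_to_all_rows(local: List[Matrix]) -> List[Matrix]:
    """Simulate all_to_all with split_axis = concat_axis = 0 in one group.

    local[i] is the tensor on the i-th device of the group. Device i splits
    it into len(local) equal row blocks and sends block j to device j, which
    concatenates the blocks it receives in sender order.
    """
    n = len(local)
    rows = len(local[0])
    assert rows % n == 0, "split axis must be divisible by the group size"
    block = rows // n
    return [
        [row for sender in local for row in sender[j * block:(j + 1) * block]]
        for j in range(n)
    ]


local = [
    [[11, 12], [13, 14], [15, 16]],  # device (0)
    [[21, 22], [23, 24], [25, 26]],  # device (1)
    [[31, 32], [33, 34], [35, 36]],  # device (2)
]
result = all_to_all_rows(local)
print(result[0])  # [[11, 12], [21, 22], [31, 32]]
print(result[1])  # [[13, 14], [23, 24], [33, 34]]
```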
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
split_axis | ::mlir::IntegerAttr | index attribute |
concat_axis | ::mlir::IntegerAttr | index attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results:
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
mesh.broadcast (mesh::BroadcastOp)
Broadcast over a device mesh.
Syntax:
operation ::= `mesh.broadcast` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Broadcast the tensor on root to all devices in each respective group. The operation broadcasts along mesh axes mesh_axes. The root device specifies the in-group multi-index that is broadcast to all other devices in the group.
Example:
mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.broadcast %0 on @mesh0
mesh_axes = [0]
root = [0]
: (tensor<2xi8>) -> tensor<2xi8>
Input:
                 +-------+-------+                   | broadcast
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                 +-------+-------+                   ↓
device (1, 0) -> |       |       | <- device (1, 1)
                 +-------+-------+
Output:
                 +-------+-------+
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
                 +-------+-------+
device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
                 +-------+-------+
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
root_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.gather (mesh::GatherOp)
Gather over a device mesh.
Syntax:
operation ::= `mesh.gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`gather_axis` `=` $gather_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Gathers on device root along the gather_axis tensor axis. root specifies the coordinates of a device along mesh_axes. It uniquely identifies the root device for each device group. The result tensor on non-root devices is undefined. Using it will result in undefined behavior.
Example:
mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
gather_axis = 1 root = [1]
: (tensor<2x2xi8>) -> tensor<2x4xi8>
Input:
                 gather tensor
                 axis 1
                 ------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
Result:
+-------------+
|  1  2  5  6 | <- device (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- device (1, 1)
| 11 12 15 16 |
+-------------+
Devices (0, 0) and (1, 0) have undefined results.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
gather_axis | ::mlir::IntegerAttr | index attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
root_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
mesh.mesh (mesh::MeshOp)
Description of a device/process mesh.
Syntax:
operation ::= `mesh.mesh` $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
attr-dict
The mesh.mesh operation is a symbol operation that identifies a specific mesh. The operation has two attributes:
sym_name: This attribute uniquely identifies the name of the mesh. This name serves as a symbolic reference to the mesh throughout the MLIR module, allowing for consistent referencing and easier debugging.
shape: This attribute represents the shape of the device mesh. It uses the same notation as a tensor shape, including dynamic dimensions. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be determined during compile time. For example, 2x?x4.
Example:
// A device mesh with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
mesh.mesh @mesh0(shape = 4x8x12)
// A device mesh with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.mesh @mesh1(shape = 4x?)
// A device mesh with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.mesh @mesh2(shape = ?x4)
// A device mesh with 2 axes, the number of devices along both axes
// is unknown
mesh.mesh @mesh3(shape = ?x?)
Interfaces: Symbol
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
sym_name | ::mlir::StringAttr | string attribute |
shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
mesh.mesh_shape (mesh::MeshShapeOp)
Get the shape of the mesh.
Syntax:
operation ::= `mesh.mesh_shape` $mesh (`axes` `=` $axes^)?
attr-dict `:` type($result)
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
Results:
Result | Description |
---|---|
result | variadic of index |
mesh.process_linear_index (mesh::ProcessLinearIndexOp)
Get the linear index of the current device.
Syntax:
operation ::= `mesh.process_linear_index` `on` $mesh attr-dict `:` type($result)
Example:
%idx = mesh.process_linear_index on @mesh : index
If @mesh has shape (10, 20, 30), a device with multi-index (1, 2, 3) will have linear index 3 + 30*2 + 20*30*1 = 663.
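The formula above is a row-major linearization in which the last mesh axis varies fastest. A minimal Python sketch of that arithmetic (linear_index is a hypothetical helper, not part of the dialect):

```python
from typing import Sequence


def linear_index(multi_index: Sequence[int], mesh_shape: Sequence[int]) -> int:
    """Row-major linearization of a device multi-index (last axis fastest)."""
    assert len(multi_index) == len(mesh_shape)
    index = 0
    for coord, size in zip(multi_index, mesh_shape):
        assert 0 <= coord < size, "coordinate out of range"
        index = index * size + coord
    return index


print(linear_index((1, 2, 3), (10, 20, 30)))  # 663, i.e. 3 + 30*2 + 20*30*1
```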
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
Results:
Result | Description |
---|---|
result | index |
mesh.process_multi_index (mesh::ProcessMultiIndexOp)
Get the multi-index of the current device along the specified mesh axes.
Syntax:
operation ::= `mesh.process_multi_index` `on` $mesh (`axes` `=` $axes^)?
attr-dict `:` type($result)
It is used in the SPMD format of IR. The axes must be non-negative and less than the total number of mesh axes. If the axes are empty, then the index along all axes is returned.
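Conversely, a device's coordinates can be recovered from a row-major linear index. The Python sketch below assumes the same row-major order as the mesh.process_linear_index example and that the returned indices follow the order of the requested axes; multi_index is a hypothetical helper:

```python
from typing import Sequence, Tuple


def multi_index(
    linear: int, mesh_shape: Sequence[int], axes: Sequence[int] = ()
) -> Tuple[int, ...]:
    """Recover mesh coordinates from a row-major linear index.

    With empty `axes`, all coordinates are returned; otherwise only those
    along `axes`, in the order given.
    """
    coords = []
    for size in reversed(mesh_shape):
        coords.append(linear % size)
        linear //= size
    coords.reverse()
    if not axes:
        return tuple(coords)
    for axis in axes:
        assert 0 <= axis < len(mesh_shape), "axis out of range"
    return tuple(coords[axis] for axis in axes)


print(multi_index(663, (10, 20, 30)))            # (1, 2, 3)
print(multi_index(663, (10, 20, 30), axes=[2]))  # (3,)
```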
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
Results:
Result | Description |
---|---|
result | variadic of index |
mesh.recv (mesh::RecvOp)
Receive over a device mesh.
Syntax:
operation ::= `mesh.recv` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
attr-dict `:` functional-type(operands, results)
Receive from a device within a device group.
Interfaces: OpAsmOpInterface, SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
source | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
source_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.reduce (mesh::ReduceOp)
Reduce over a device mesh.
Syntax:
operation ::= `mesh.reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`reduction` `=` $reduction^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Reduces on device root within each device group. root specifies the coordinates of a device along mesh_axes. It uniquely identifies the root device within its device group. The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.
Attributes:
reduction: Indicates the reduction method.
Example:
%1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
reduction = <max> root = [2, 3]
: (tensor<3x4xf32>) -> tensor<3x4xf64>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
root_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.reduce_scatter (mesh::ReduceScatterOp)
Reduce-scatter over a device mesh.
Syntax:
operation ::= `mesh.reduce_scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`reduction` `=` $reduction^)?
`scatter_axis` `=` $scatter_axis
attr-dict `:` type($input) `->` type($result)
After the reduction, the result is scattered within each device group. The tensor is split along scatter_axis and the pieces are distributed across the device group.
Example:
mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
  reduction = <sum> scatter_axis = 0
  : tensor<2x2xf32> -> tensor<1x2xf64>
Input:
                          device
                          (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                 |  3  4 |  7  8 |  ↓
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 |
                 | 11 12 | 15 16 |
                 +-------+-------+
                             ↑
                          device
                          (1, 1)
Result:
+-------+
|  6  8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
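The numbers in the diagram correspond to a sum reduction followed by a row split. The Python sketch below reproduces them for one device group at a time, assuming scatter_axis = 0; reduce_scatter_rows is a hypothetical name:

```python
from typing import List

Matrix = List[List[int]]


def reduce_scatter_rows(group: List[Matrix]) -> List[Matrix]:
    """Simulate a sum reduce_scatter with scatter_axis = 0 in one group.

    group[i] is the tensor on the i-th device of the group. The tensors are
    summed element-wise, then split into len(group) equal row blocks, and
    device i keeps block i.
    """
    n = len(group)
    rows, cols = len(group[0]), len(group[0][0])
    assert rows % n == 0, "scatter axis must be divisible by the group size"
    reduced = [[sum(t[r][c] for t in group) for c in range(cols)] for r in range(rows)]
    block = rows // n
    return [reduced[i * block:(i + 1) * block] for i in range(n)]


# mesh_axes = [1]: devices (0, 0), (0, 1) form one group, (1, 0), (1, 1) another.
print(reduce_scatter_rows([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))
# [[[6, 8]], [[10, 12]]]
print(reduce_scatter_rows([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]))
# [[[22, 24]], [[26, 28]]]
```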
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. |
scatter_axis | ::mlir::IntegerAttr | index attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.scatter (mesh::ScatterOp)
Scatter over a device mesh.
Syntax:
operation ::= `mesh.scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`scatter_axis` `=` $scatter_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
For each device group, split the input tensor on the root device along axis scatter_axis and scatter the parts across the group devices.
Example:
mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
scatter_axis = 0
root = [1]
: (tensor<2x2xi8>) -> tensor<1x2xi8>
Input:
                          device
                          (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |       |       |  | axis 0
                 |       |       |  ↓
                 +-------+-------+
device (1, 0) -> |  1  2 |  5  6 |
                 |  3  4 |  7  8 |
                 +-------+-------+
                             ↑
                          device
                          (1, 1)
Result:
                          device
                          (0, 1)
                             ↓
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 |
                 +-------+-------+
device (1, 0) -> |  3  4 |  7  8 |
                 +-------+-------+
                             ↑
                          device
                          (1, 1)
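The scatter example can be reproduced with the Python sketch below. It is restricted to a 2D mesh, a single entry in mesh_axes and scatter_axis = 0, and assumes that non-root inputs are simply ignored; scatter_rows is a hypothetical name:

```python
from typing import Dict, List, Optional, Tuple

Device = Tuple[int, int]
Matrix = List[List[int]]


def scatter_rows(
    local: Dict[Device, Optional[Matrix]],
    mesh_shape: Tuple[int, int],
    mesh_axis: int,
    root: int,
) -> Dict[Device, Matrix]:
    """Simulate scatter with scatter_axis = 0 on a 2D mesh, one mesh axis.

    Within each group, the root is the device whose coordinate along
    `mesh_axis` equals `root`. Its tensor is split into equal row blocks and
    the device with coordinate k along `mesh_axis` receives block k.
    Inputs on non-root devices are ignored.
    """
    group_size = mesh_shape[mesh_axis]
    result = {}
    for device in local:
        root_device = tuple(root if a == mesh_axis else device[a] for a in range(2))
        data = local[root_device]
        assert data is not None and len(data) % group_size == 0
        block = len(data) // group_size
        k = device[mesh_axis]
        result[device] = data[k * block:(k + 1) * block]
    return result


local = {
    (0, 0): None,
    (0, 1): None,
    (1, 0): [[1, 2], [3, 4]],
    (1, 1): [[5, 6], [7, 8]],
}
out = scatter_rows(local, mesh_shape=(2, 2), mesh_axis=0, root=1)
print(out[(0, 0)], out[(1, 0)])  # [[1, 2]] [[3, 4]]
print(out[(0, 1)], out[(1, 1)])  # [[5, 6]] [[7, 8]]
```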
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
scatter_axis | ::mlir::IntegerAttr | index attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
root_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.send (mesh::SendOp)
Send over a device mesh.
Syntax:
operation ::= `mesh.send` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
attr-dict `:` functional-type(operands, results)
Send from one device to another within a device group.
Interfaces: OpAsmOpInterface, SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
destination | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
destination_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.shard (mesh::ShardOp)
Annotate how a tensor is sharded across a mesh.
Syntax:
operation ::= `mesh.shard` $src `to` $sharding
(`annotate_for_users` $annotate_for_users^)?
attr-dict `:` type($result)
The mesh.shard operation is designed to specify and guide the sharding behavior of a tensor value across a mesh topology. This operation has two operands and one optional attribute:
src: This operand represents the tensor value that needs to be annotated for sharding.
sharding: This operand is of type MeshShardingType (!mesh.sharding), which is the core data structure to represent the distribution of a tensor on a mesh. It is typically defined by a mesh.sharding operation.
annotate_for_users: A unit attribute addressing the scenario when a tensor’s sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.
Example:
func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
%sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
...
}
func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
%sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
...
}
func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
%sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
%1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
...
}
// The first mesh.shard op applies to %arg0, the second mesh.shard op
// applies to the operand of op0, the third mesh.shard op applies to the
// operand of op1
func.func @both_result_and_multi_operands_annotated(
%arg0 : tensor<4x8xf32>) -> () {
%sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
%sharding1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
%1 = mesh.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%sharding2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
%2 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
"op0"(%1) : ...
"op1"(%2) : ...
...
}
The following usages are undefined:
func.func @annotate_on_same_result_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
%1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_result_same_value_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
%1 = mesh.shard %arg0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_operand_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%1 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
...
}
func.func @result_annotated_after_operand(
%arg0 : tensor<4x8xf32>) -> () {
%sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
%sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
%0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
%1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
annotate_for_users | ::mlir::UnitAttr | unit attribute |
Operands:
Operand | Description |
---|---|
src | ranked tensor of any type values |
sharding | sharding definition |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.shard_shape (mesh::ShardShapeOp)
Get the shard shape of a given process/device.
Syntax:
operation ::= `mesh.shard_shape` custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
The device/process id is a linearized id of the device/process in the mesh. This operation might be used during spmdization when the shard shape depends on (non-constant) values used in mesh.sharding.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
sharding | sharding definition |
device | index |
Results:
Result | Description |
---|---|
result | variadic of index |
mesh.sharding (mesh::ShardingOp)
Define a sharding of a tensor.
Syntax:
operation ::= `mesh.sharding` $mesh
`split_axes` `=` $split_axes
(`partial` `=` $partial_type $partial_axes^)?
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
(`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
attr-dict `:` type($result)
The MeshSharding specifies how a tensor is sharded and distributed across the process mesh. It is typically used in a mesh.shard operation.
The operation has the following attributes and operands:
mesh: this attribute is a FlatSymbolRefAttr that refers to the device mesh where the distributed tensor is placed. The symbol must resolve to a mesh.mesh operation.
split_axes: an array composed of int64_t sub-arrays. The outer array’s maximum size is the rank of the related tensor. For the i-th sub-array, if its value is [x, y], it indicates that the tensor’s i-th dimension is split along the x and y axes of the device mesh.
[Optional] partial_axes: if not empty, this signifies that the tensor is partial along the specified mesh axes. An all-reduce should be applied to obtain the complete tensor, with the reduction type being specified by partial_type.
[Optional] partial_type: indicates the reduction type of the possible all-reduce op. It has 4 possible values; generic is not an allowed value inside a shard attribute.
[Optional] halo_sizes: sizes of halos to be added for each sharded tensor dimension. halo_sizes is provided as a flattened 1d array of i64s, 2 values for each sharded dimension. halo_sizes = [1, 2] means that the first sharded dimension gets an additional halo of size 1 at its start and a halo of size 2 at its end. halo_sizes = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions, i.e. the first sharded dimension gets [1, 2] halos and the second gets [2, 3] halos. ? indicates dynamic halo sizes.
[Optional] sharded_dims_sizes: sizes of sharded dimensions of each shard. sharded_dims_sizes is provided as a flattened 1d array of i64s: for each device of the device mesh, one value for each sharded tensor dimension. Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded, sharded_dims_sizes = [16, 8, 16, 24] means that the first device of the device mesh will get a shard of shape 16x8x32 and the second device will get a shard of shape 16x24x32. ? indicates dynamic shard dimensions.
halo_sizes and sharded_dims_sizes are mutually exclusive.
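Reading the halo_sizes and sharded_dims_sizes layouts described above, a Python sketch of the resulting per-device shapes might look like this. shard_shapes and shape_with_halos are hypothetical helpers. The halo function assumes a halo simply extends the local extent of its dimension by the start and end sizes, and the (2, 8) shard shape in the usage line is an assumed value:

```python
from typing import List, Sequence, Tuple


def shard_shapes(
    tensor_shape: Sequence[int],
    sharded_dims: Sequence[int],
    sharded_dims_sizes: Sequence[int],
) -> List[Tuple[int, ...]]:
    """Per-device shard shapes from a flattened sharded_dims_sizes array.

    The array holds, device after device, one size per sharded tensor
    dimension; dimensions that are not sharded keep their full extent.
    """
    k = len(sharded_dims)
    assert len(sharded_dims_sizes) % k == 0
    shapes = []
    for start in range(0, len(sharded_dims_sizes), k):
        shape = list(tensor_shape)
        for dim, size in zip(sharded_dims, sharded_dims_sizes[start:start + k]):
            shape[dim] = size
        shapes.append(tuple(shape))
    return shapes


def shape_with_halos(
    shard_shape: Sequence[int], sharded_dims: Sequence[int], halo_sizes: Sequence[int]
) -> Tuple[int, ...]:
    """Grow each sharded dimension by its (start, end) halo pair (assumption)."""
    shape = list(shard_shape)
    for i, dim in enumerate(sharded_dims):
        shape[dim] += halo_sizes[2 * i] + halo_sizes[2 * i + 1]
    return tuple(shape)


print(shard_shapes((32, 32, 32), [0, 1], [16, 8, 16, 24]))  # [(16, 8, 32), (16, 24, 32)]
print(shape_with_halos((2, 8), [0], [1, 2]))                # (5, 8)
```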
Examples:
mesh.mesh @mesh0(shape = 2x2x4)
mesh.mesh @mesh1d_4(shape = 4)
// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
%sharding0 = mesh.sharding @mesh0 split_axes = [[]]
// The tensor is sharded on the first dimension along axis 0 of @mesh0
%sharding1 = mesh.sharding @mesh0 split_axes = [[0]]
// The tensor is sharded on its first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
%sharding2 = mesh.sharding @mesh0 split_axes = [[0] split_axes = []] partial = sum[1]
// The tensor is sharded on its first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
%sharding3 = mesh.sharding @mesh0 split_axes = [[0]] partial = max[1]
// Could be used for a mesh.shard op
%sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
// The tensor is sharded on its first dimension along axis 0 of @mesh0
// and it has halo sizes of 1 and 2 on the sharded dim.
%halo_sharding = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2]
%sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>
// The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
%sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_sizes = [2, 3, 4, 5]
%sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
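The shard shapes implied by these examples can be checked with a small Python sketch. It is not the MLIR shape-inference code, and the function names are hypothetical. It assumes three things. First, a sharded dimension is split evenly by the product of the sizes of its mesh axes. Second, halos are added to the local extent as (start, end) pairs. Third, sharded_dims_sizes is laid out device by device, as described above. Dynamic (?) sizes are not modeled.

```python
from math import prod
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def sharded_dims(split_axes: Sequence[Sequence[int]]) -> List[int]:
    """Tensor dimensions that have at least one mesh axis listed in split_axes."""
    return [d for d, axes in enumerate(split_axes) if axes]


def even_local_shape(tensor_shape: Shape, mesh_shape: Shape,
                     split_axes: Sequence[Sequence[int]],
                     halo_sizes: Sequence[int] = ()) -> Shape:
    """Local shard shape for an even split, optionally extended by halos.

    Assumes each sharded dimension divides evenly by the product of the sizes
    of its mesh axes; halo_sizes holds (start, end) pairs per sharded dimension.
    """
    shape = list(tensor_shape)
    for k, d in enumerate(sharded_dims(split_axes)):
        num_shards = prod(mesh_shape[a] for a in split_axes[d])
        assert shape[d] % num_shards == 0, "uneven split is not modeled"
        shape[d] //= num_shards
        if halo_sizes:
            shape[d] += halo_sizes[2 * k] + halo_sizes[2 * k + 1]
    return tuple(shape)


def shapes_from_sharded_dims_sizes(tensor_shape: Shape, dims: Sequence[int],
                                   sizes: Sequence[int], num_devices: int) -> List[Shape]:
    """Decode the flattened sharded_dims_sizes array (one group of values per device)."""
    assert len(sizes) == len(dims) * num_devices
    shapes = []
    for dev in range(num_devices):
        shape = list(tensor_shape)
        for j, d in enumerate(dims):
            shape[d] = sizes[dev * len(dims) + j]
        shapes.append(tuple(shape))
    return shapes


# %halo_sharding on tensor<4x8xf32> over @mesh0 (2x2x4).
print(even_local_shape((4, 8), (2, 2, 4), [[0]], halo_sizes=(1, 2)))
# %sharding4 on tensor<4x14xf32> over @mesh1d_4 (4 devices).
print(shapes_from_sharded_dims_sizes((4, 14), sharded_dims([[], [0]]), [2, 3, 4, 5], 4))
```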
Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
split_axes | ::mlir::mesh::MeshAxesArrayAttr | |
partial_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
partial_type | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. Enum cases: sum, max, min, product, average, bitwise_and, bitwise_or, bitwise_xor, generic |
static_sharded_dims_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
static_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
dynamic_sharded_dims_sizes | variadic of 64-bit signless integer |
dynamic_halo_sizes | variadic of 64-bit signless integer |
Results:
Result | Description |
---|---|
result | sharding definition |
mesh.shift (mesh::ShiftOp)
Shift over a device mesh.
Syntax:
operation ::= `mesh.shift` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`shift_axis` `=` $shift_axis
`offset` `=` $offset
(`rotate` $rotate^)?
attr-dict `:` type($input) `->` type($result)
Within each device group, shift along mesh axis shift_axis by offset offset.
The result on devices that do not have a corresponding source is undefined.
shift_axis must be one of mesh_axes.
If the rotate attribute is present, a rotation is performed instead of a shift.
Example:
mesh.mesh @mesh0(shape = 2x4)
%1 = mesh.shift on @mesh0 mesh_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
Input:
mesh axis 1
----------->
+----+----+----+----+
| 1 | 2 | 3 | 4 |
+----+----+----+----+
| 5 | 6 | 7 | 8 |
+----+----+----+----+
Result:
+----+----+----+----+
| 3 | 4 | 1 | 2 |
+----+----+----+----+
| 7 | 8 | 5 | 6 |
+----+----+----+----+
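The example can be reproduced with a small Python model of the data movement. It is not MLIR code. The function name mesh_shift and the use of None for undefined results are illustrative choices. The op description does not state the shift direction, so this sketch assumes that the device at position p along shift_axis receives from position p - offset. For the rotation by 2 on a 4-device axis shown above, both directions give the same result.

```python
from itertools import product
from typing import Dict, Optional, Sequence, Tuple

Coord = Tuple[int, ...]


def mesh_shift(values: Dict[Coord, int], mesh_shape: Sequence[int],
               mesh_axes: Sequence[int], shift_axis: int, offset: int,
               rotate: bool = False) -> Dict[Coord, Optional[int]]:
    """Model mesh.shift; None marks a device whose result is undefined."""
    assert shift_axis in mesh_axes, "shift_axis must be one of mesh_axes"
    size = mesh_shape[shift_axis]
    result: Dict[Coord, Optional[int]] = {}
    for dst in product(*(range(n) for n in mesh_shape)):
        # Assumed direction: position p receives from position p - offset.
        src_pos = dst[shift_axis] - offset
        if rotate:
            src_pos %= size  # wrap around the axis
        elif not 0 <= src_pos < size:
            result[dst] = None  # no corresponding source: undefined
            continue
        src = dst[:shift_axis] + (src_pos,) + dst[shift_axis + 1:]
        result[dst] = values[src]
    return result


# The 2x4 example above: device (i, j) holds 4 * i + j + 1.
grid = {(i, j): 4 * i + j + 1 for i in range(2) for j in range(4)}
rotated = mesh_shift(grid, (2, 4), mesh_axes=[1], shift_axis=1, offset=2, rotate=True)
print([[rotated[(i, j)] for j in range(4)] for i in range(2)])  # [[3, 4, 1, 2], [7, 8, 5, 6]]
```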
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultShape
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
shift_axis | ::mlir::IntegerAttr | index attribute |
offset | ::mlir::IntegerAttr | 64-bit signless integer attribute |
rotate | ::mlir::UnitAttr | unit attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.update_halo (mesh::UpdateHaloOp)
Update halo data.
Syntax:
operation ::= `mesh.update_halo` $input `on` $mesh
`split_axes` `=` $split_axes
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
(`target_halo_sizes` `=` $target_halo_sizes^)?
attr-dict `:` type($input)
This operation updates the halo regions of shards, e.g. when their sharding specifies halos and the actual tensor data may have changed on remote devices. Such changes can be caused by mutating operations and/or by the new halo regions being larger than the existing ones.
It assumes that all devices hold tensors with same-sized halo data, as specified by dynamic_halo_sizes/static_halo_sizes.
split_axes specifies, for each tensor axis, the mesh axes along which its halo data is updated.
Optionally, the halos are resized to the new sizes given by target_halo_sizes.
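As a rough illustration of the data movement, here is a Python model of a halo update for one tensor axis split along a single mesh axis. It is neither the MLIR op nor any of its lowerings. The following are assumptions made for illustration:

- the function name update_halo_1d;
- the [start halo | core | end halo] list layout;
- leaving halos untouched at the ends of the mesh axis.

Resizing via target_halo_sizes is not modeled.

```python
from typing import List, Sequence


def update_halo_1d(shards: Sequence[Sequence[int]], halo_lo: int, halo_hi: int) -> List[List[int]]:
    """Refresh halos of shards split along one mesh axis (1-D model).

    Each shard is laid out as [start halo | core data | end halo].
    Halos at the ends of the mesh axis have no neighbor and are left unchanged.
    """
    # Core (owned) data of every shard, with both halos stripped.
    cores = [list(s[halo_lo:len(s) - halo_hi]) for s in shards]
    updated = []
    for i, shard in enumerate(shards):
        new = list(shard)
        if i > 0 and halo_lo:
            # Start halo comes from the tail of the previous shard's core.
            new[:halo_lo] = cores[i - 1][-halo_lo:]
        if i + 1 < len(shards) and halo_hi:
            # End halo comes from the head of the next shard's core.
            new[len(new) - halo_hi:] = cores[i + 1][:halo_hi]
        updated.append(new)
    return updated


# Two shards of the global data 0..7 with halo sizes [1, 2]; -1 marks stale halo data.
stale = [[-1, 0, 1, 2, 3, -1, -1], [-1, 4, 5, 6, 7, -1, -1]]
print(update_halo_1d(stale, 1, 2))  # [[-1, 0, 1, 2, 3, 4, 5], [3, 4, 5, 6, 7, -1, -1]]
```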
Interfaces: SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
split_axes | ::mlir::mesh::MeshAxesArrayAttr | |
static_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
target_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked memref of any type values |
dynamic_halo_sizes | variadic of 64-bit signless integer |
Attributes
MeshAxesArrayAttr
Syntax:
#mesh.axisarray<
::llvm::ArrayRef<MeshAxesAttr> # axes
>
Parameters:
Parameter | C++ type | Description |
---|---|---|
axes | ::llvm::ArrayRef<MeshAxesAttr> | |
ReductionKindAttr
Reduction of an iterator/mesh dimension.
Syntax:
#mesh.partial<
::mlir::mesh::ReductionKind # value
>
Enum cases:
- sum (Sum)
- max (Max)
- min (Min)
- product (Product)
- average (Average)
- bitwise_and (BitwiseAnd)
- bitwise_or (BitwiseOr)
- bitwise_xor (BitwiseXor)
- generic (Generic)
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::mesh::ReductionKind | an enum of type ReductionKind |