'mesh' Dialect
The mesh dialect contains a set of attributes, operations, and interfaces that are useful for representing sharding and communication on a device mesh cluster.
Collective Communication Operations ¶
There are a number of operations in the Mesh dialect to facilitate communication between devices in a mesh. It is assumed that the user is familiar with collective operations. Wikipedia has a good explanation. The main addition is that the collectives in this dialect have mesh semantics.
The operation attributes mesh and mesh_axes specify a list of device mesh axes that partition the devices into disjoint groups. The collective operation is performed between devices in the same group. Devices that have the same coordinates outside of the axes in mesh_axes are in the same group.
For example, if we have a device mesh of size 2x3x4x5 and the partition mesh axes list is [0, 1], then the devices are partitioned into the groups { { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }, i.e. 4 * 5 = 20 groups of 2 * 3 = 6 devices each. Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group. Device (1, 0, 2, 4) will be in another group.
Some collective operations, like all-to-all and all-gather, care about the order of devices. The order of devices in a device group is induced by the order of axes in mesh_axes. The axes are ordered from outer to inner. If we have an axis list [3, 1], then device (i, 1, k, 0) will precede both devices (i, 0, k, 1) and (i, 2, k, 0).
Operations ¶
mesh.all_gather (mesh::AllGatherOp) ¶
All-gather over a device mesh.
Syntax:
operation ::= `mesh.all_gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
attr-dict `:` type($input) `->` type($result)
Gathers along the gather_axis tensor axis.
Example:
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
Input:
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
Result:
gather tensor
axis 1
------------>
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
Traits: SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: SymbolUserOpInterface
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
gather_axis | ::mlir::IntegerAttr | index attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
mesh.all_reduce (mesh::AllReduceOp) ¶
All-reduce over a device mesh.
Syntax:
operation ::= `mesh.all_reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
attr-dict `:` type($input) `->` type($result)
The accumulation element type is specified by the result type, and it does not need to match the input element type. The input elements are converted to the result element type before the reduction is performed.
Attributes:
reduction: Indicates the reduction method.
Example:
%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
: tensor<3x4xf32> -> tensor<3x4xf64>
Traits: SameOperandsAndResultShape
Interfaces: SymbolUserOpInterface
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::mesh::PartialAttr | partial type of a distributed tensor |
Operands: ¶
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.all_to_all (mesh::AllToAllOp) ¶
All-to-all over a device mesh.
Syntax:
operation ::= `mesh.all_to_all` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`split_axis` `=` $split_axis
`concat_axis` `=` $concat_axis
attr-dict `:` type($input) `->` type($result)
Performs an all-to-all on tensor pieces split along split_axis. The resulting pieces are concatenated along concat_axis on each device.
Example:
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
: tensor<3x2xi8> -> tensor<3x2xi8>
Input:
device device device
(0) (1) (2)
+-------+-------+-------+ | split and concat along
| 11 12 | 21 22 | 31 32 | | tensor axis 0
| 13 14 | 23 24 | 33 34 | ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+
Result:
device device device
(0) (1) (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
Traits: SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: SymbolUserOpInterface
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
split_axis | ::mlir::IntegerAttr | index attribute |
concat_axis | ::mlir::IntegerAttr | index attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
mesh.cluster (mesh::ClusterOp) ¶
Representing a mesh cluster
Syntax:
operation ::= `mesh.cluster` $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
attr-dict
The mesh.cluster operation is a symbol operation that identifies a specific mesh cluster. The operation has three attributes:
1. sym_name: This attribute uniquely identifies the name of the mesh cluster. This name serves as a symbolic reference to the cluster throughout the MLIR module, allowing for consistent referencing and easier debugging.
2. rank: This attribute specifies the number of axes of the cluster. The rank indicates the dimensionality of the mesh cluster and can be used to determine the layout and the addressing space of the computation distributed across the mesh.
3. dim_sizes: This attribute represents the device assignment along the axes of the cluster. Each integer in the array corresponds to the number of devices along a specific axis. If an integer value is 0, it implies that the number of devices along that axis is unknown. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be determined during compile time.
Example:
// A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])
// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
// A device mesh cluster with 2 axes, the number of devices along both axes
// is unknown
mesh.cluster @mesh3(rank = 2)
// Used in the mesh sharding attribute to extend the standard tensor to
// distributed
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
Interfaces: Symbol
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
sym_name | ::mlir::StringAttr | string attribute |
rank | ::mlir::IntegerAttr | 64-bit signless integer attribute |
dim_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
mesh.reduce_scatter (mesh::ReduceScatterOp) ¶
Reduce-scatter over a device mesh.
Syntax:
operation ::= `mesh.reduce_scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`reduction` `=` $reduction^)?
`scatter_axis` `=` $scatter_axis
attr-dict `:` type($input) `->` type($result)
After the reduction, the result is scattered within each device group. The tensor is split along scatter_axis and the pieces are distributed across the device group.
Example:
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
  reduction = <sum> scatter_axis = 0
  : tensor<2x2xf32> -> tensor<1x2xf64>
Input:
device
(0, 1)
↓
+-------+-------+ | scatter tensor
device (0, 0) -> | 1 2 | 5 6 | | axis 0
| 3 4 | 7 8 | ↓
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 |
| 11 12 | 15 16 |
+-------+-------+
↑
device
(1, 1)
Result:
+-------+
| 6 8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
Traits: SameOperandsAndResultRank
Interfaces: SymbolUserOpInterface
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::mesh::PartialAttr | partial type of a distributed tensor |
scatter_axis | ::mlir::IntegerAttr | index attribute |
Operands: ¶
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results: ¶
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.shard (mesh::ShardOp) ¶
Annotate on how a tensor is sharded across a mesh cluster.
Syntax:
operation ::= `mesh.shard` $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
type($result)
The mesh.shard operation is designed to specify and guide the sharding behavior of a tensor value across a mesh topology. This operation has one operand and two attributes:
1. input: This operand represents the tensor value that needs to be annotated for sharding.
2. shard: This attribute is of type MeshSharding, which is the core data structure used to represent a distributed tensor in a mesh cluster.
3. annotate_for_users: A unit attribute addressing the scenario when a tensor's sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.
Example:
func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
...
}
func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
...
}
// The first mesh.shard op applies to %arg0, the second mesh.shard op
// applies for the operand of op0, the third mesh.shard op applies for the
// operand of op1
func.func @both_result_and_multi_operands_annotated(
%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
%1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
%2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
"op0"(%1) : ...
"op1"(%2) : ...
...
}
The following usages are undefined:
func.func @annotate_on_same_result_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
%1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
...
}
func.func @annotate_on_same_result_same_value_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
%1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32>
...
}
func.func @annotate_on_same_operand_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
%1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
...
}
func.func @result_annotated_after_operand(
%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
%1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
...
}
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
shard | ::mlir::mesh::MeshShardingAttr | Attribute that extends tensor type to distributed tensor type. |
annotate_for_users | ::mlir::UnitAttr | unit attribute |
Operands: ¶
Operand | Description |
---|---|
src | Multi-dimensional array with a fixed number of dimensions |
Results: ¶
Result | Description |
---|---|
result | Multi-dimensional array with a fixed number of dimensions |
Attributes ¶
MeshShardingAttr ¶
Attribute that extends tensor type to distributed tensor type.
Syntax:
#mesh.shard<
::mlir::SymbolRefAttr, # cluster
::llvm::ArrayRef<::mlir::DenseI32ArrayAttr>, # split_axes
::llvm::ArrayRef<int32_t>, # partial_axes
::mlir::mesh::Partial # partial_type
>
The MeshSharding attribute can be used in the encoding of a RankedTensorType or in the mesh.shard op. It contains four sub-attributes:
1. cluster: This attribute is a SymbolRefAttr that refers to the mesh cluster where the distributed tensor is placed. The symbol must resolve to a mesh.cluster operation.
2. split_axes: An array composed of int32_t sub-arrays. The outer array's maximum size is the rank of the related tensor. For the i-th sub-array, if its value is [x, y], it indicates that the tensor's i-th dimension is split along the x and y axes of the device mesh.
3. partial_axes: If not empty, this signifies that the tensor is a partial tensor along the specified mesh axes. An all-reduce should be applied to obtain the complete tensor, with the reduction type specified by partial_type.
4. partial_type: Indicates the reduction type of the possible all-reduce op. It has 4 possible values:
- partial_sum: denotes an all-reduce-sum
- partial_max: denotes an all-reduce-max
- partial_min: denotes an all-reduce-min
- partial_generic: denotes that the all-reduce type is complex and cannot be represented merely by a simple sum, max, or min. The exact reduction computation may be derived from the semantics of the corresponding operation or from the reduction computation IR
Example:
mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
// The tensor is sharded on the first dimension along axis 0 of @mesh0
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = sum[1]>>
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>>
// Could be used in the attribute of mesh.shard op
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
cluster | ::mlir::SymbolRefAttr | cluster placed |
split_axes | ::llvm::ArrayRef<::mlir::DenseI32ArrayAttr> | |
partial_axes | ::llvm::ArrayRef<int32_t> | |
partial_type | ::mlir::mesh::Partial | |
PartialAttr ¶
partial type of a distributed tensor
Syntax:
#mesh.partial<
::mlir::mesh::Partial # value
>
Enum cases:
- sum (Sum)
- max (Max)
- min (Min)
- generic (Generic)
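The enum value selects the reduction kind wherever one is needed; for reference, the two usage sites shown earlier in this document are the reduction attribute of collective operations and the partial_type of a MeshShardingAttr, repeated here from the examples above:
%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
    : tensor<3x4xf32> -> tensor<3x4xf64>
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>>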
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::mesh::Partial | an enum of type Partial |