'mesh' Dialect
The mesh dialect contains a set of attributes, operations and interfaces that are useful for representing sharding and communication on a device mesh cluster.
Collective Communication Operations
There are a number of operations in the Mesh dialect to facilitate communication between devices in a mesh. It is assumed that the user is familiar with collective operations. Wikipedia has a good explanation. The main addition is that the collectives in this dialect have mesh semantics.
Device groups
The operation attributes mesh and mesh_axes specify a device mesh and a list of its axes that partition the devices into disjoint groups.
The collective operation is performed between devices in the same group.
Devices that have the same coordinates outside of the axes in mesh_axes are in the same group.
A group is described by its multi-index along the axes outside of mesh_axes.
For example, if we have a device mesh of size 2x3x4x5 and the partition mesh axes list is [0, 1], then devices are partitioned into the groups { { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }.
The device groups are indexed by { (k, m) | 0<=k<4, 0<=m<5 }.
Devices (1, 0, 2, 3) and (1, 1, 2, 3) are in the same group.
Device (1, 0, 2, 4) is in another group.
Some collective operations, like all-to-all and all-gather, care about the order of devices.
The order of devices in a device group is induced by the order of axes in mesh_axes.
The axes are ordered from outer to inner.
If we have an axis list [3, 1], then device (i, 1, k, 0) precedes both devices (i, 0, k, 1) and (i, 2, k, 0).
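The grouping and ordering rules above can be modeled with a short Python sketch. It is illustrative only and not part of the dialect: device_groups is a hypothetical helper that enumerates every device coordinate, keys each device by its coordinates outside mesh_axes, and sorts each group lexicographically by the coordinates along mesh_axes in the listed order (first listed axis outermost).

```python
from itertools import product


def device_groups(mesh_shape, mesh_axes):
    """Partition device coordinates into groups (hypothetical helper)."""
    outer_axes = [a for a in range(len(mesh_shape)) if a not in mesh_axes]
    groups = {}
    for dev in product(*(range(n) for n in mesh_shape)):
        # Devices that agree on every axis outside mesh_axes share a group.
        key = tuple(dev[a] for a in outer_axes)
        groups.setdefault(key, []).append(dev)
    for members in groups.values():
        # In-group order: lexicographic over mesh_axes, first listed = outermost.
        members.sort(key=lambda d: tuple(d[a] for a in mesh_axes))
    return groups


groups = device_groups((2, 3, 4, 5), [0, 1])
print(len(groups), len(groups[(2, 3)]))  # 20 6

ordered = device_groups((2, 3, 4, 5), [3, 1])[(0, 0)]  # group with i = 0, k = 0
print(ordered[:3])  # [(0, 0, 0, 0), (0, 1, 0, 0), (0, 2, 0, 0)]
```

The group key is the multi-index along the axes outside mesh_axes, which matches the (k, m) indexing in the 2x3x4x5 example.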
In-group Device
Some operations like broadcast, scatter and send specify devices in each device group.
These devices are represented with their multi-index over the mesh axes that are not constant within a device group.
These are the axes specified by the mesh_axes attribute.
For example, on a 3D mesh an operation with mesh_axes = [0, 2] would specify an in-group device with (i, j). Then for each group with index g on the second axis, the in-group device would be (i, g, j).
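The mapping from an in-group device to full mesh coordinates can be sketched as follows. The helper in_group_device is hypothetical; it assumes the in-group index follows the order of mesh_axes and the group index follows the remaining axes in ascending order.

```python
def in_group_device(mesh_rank, mesh_axes, in_group_index, group_index):
    """Combine an in-group index and a group index into full mesh coordinates.

    in_group_index lists coordinates along mesh_axes, in the order given.
    group_index lists coordinates along the remaining axes, in ascending order.
    """
    coords = [None] * mesh_rank
    for axis, c in zip(mesh_axes, in_group_index):
        coords[axis] = c
    outer_axes = [a for a in range(mesh_rank) if a not in mesh_axes]
    for axis, c in zip(outer_axes, group_index):
        coords[axis] = c
    return tuple(coords)


# 3D mesh, mesh_axes = [0, 2]: in-group device (i, j) = (1, 3) in group g = 2.
print(in_group_device(3, [0, 2], (1, 3), (2,)))  # (1, 2, 3)
```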
Purity
Collectives that involve the whole device group to perform a single operation are pure. The exceptions are send and recv.
There is an assumption that the execution is SPMD. This means not only that each process runs the same program, but also that at the point of execution of a collective operation, all processes are in a coherent state. All compiler transformations must be consistent: collective operations in the IR that may correspond to the same runtime collective operation must be transformed in a consistent manner. For example, if a collective operation is optimized out, then it must also not appear in any path of execution on any process.
Having the operations as Pure implies that if an interpreter is to execute the IR containing the mesh collectives, all processes would execute the same line when they reach a pure collective operation.
This requirement stems from the need to be compatible with general optimization passes like dead code and common sub-expression elimination.
Operations
mesh.all_gather (mesh::AllGatherOp)
All-gather over a device mesh.
Syntax:
operation ::= `mesh.all_gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
attr-dict `:` type($input) `->` type($result)
Gathers along the gather_axis tensor axis.
Example:
mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
Input:
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
Result:
gather tensor
axis 1
------------>
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
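To make the example concrete, here is a small Python model of all-gather on a 2-D per-device tensor. It is a sketch, not the dialect's lowering: concat and all_gather are hypothetical helpers, tensors are nested lists, and inputs maps each device coordinate to its local tensor.

```python
from itertools import product


def concat(tensors, axis):
    """Concatenate 2-D nested lists along axis 0 (rows) or axis 1 (columns)."""
    if axis == 0:
        return [row for t in tensors for row in t]
    return [sum((t[r] for t in tensors), []) for r in range(len(tensors[0]))]


def all_gather(inputs, mesh_shape, mesh_axes, gather_axis):
    """inputs maps device coordinates to that device's local 2-D tensor."""
    outer = [a for a in range(len(mesh_shape)) if a not in mesh_axes]
    groups = {}
    for dev in product(*(range(n) for n in mesh_shape)):
        groups.setdefault(tuple(dev[a] for a in outer), []).append(dev)
    result = {}
    for members in groups.values():
        members.sort(key=lambda d: tuple(d[a] for a in mesh_axes))
        gathered = concat([inputs[d] for d in members], gather_axis)
        for d in members:  # every device in the group gets the same tensor
            result[d] = gathered
    return result


inputs = {
    (0, 0): [[1, 2], [3, 4]], (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]], (1, 1): [[13, 14], [15, 16]],
}
out = all_gather(inputs, (2, 2), mesh_axes=[1], gather_axis=1)
print(out[(0, 0)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
print(out[(1, 1)])  # [[9, 10, 13, 14], [11, 12, 15, 16]]
```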
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
gather_axis | ::mlir::IntegerAttr | index attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results:
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
mesh.all_reduce (mesh::AllReduceOp)
All-reduce over a device mesh.
Syntax:
operation ::= `mesh.all_reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
attr-dict `:` type($input) `->` type($result)
The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.
Attributes:
reduction: Indicates the reduction method.
Example:
%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
: tensor<3x4xf32> -> tensor<3x4xf64>
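The accumulation rule (convert each input element to the result element type, then reduce) can be sketched in Python for a single device group. The names COMBINERS and all_reduce_group are hypothetical; Python float stands in for the wider result element type, and only a subset of the reduction kinds is modeled.

```python
import operator
from functools import reduce as fold

# Binary combiners for a few ReductionKind cases (illustrative subset).
COMBINERS = {"sum": operator.add, "product": operator.mul, "max": max, "min": min}


def all_reduce_group(group_inputs, kind, to_result_type=float):
    """Element-wise all-reduce of the 1-D values held by one device group.

    Elements are converted to the result element type first, so the
    accumulation happens in that type; every device receives the same values.
    """
    combine = COMBINERS[kind]
    converted = [[to_result_type(x) for x in t] for t in group_inputs]
    reduced = [fold(combine, column) for column in zip(*converted)]
    return [list(reduced) for _ in group_inputs]


print(all_reduce_group([[1, 7], [5, 2]], "max"))  # [[5.0, 7.0], [5.0, 7.0]]
```

Converting before accumulating is what lets, for example, many small integer inputs be summed without overflowing the narrower input type.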
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultShape
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. Enum cases: sum, max, min, product, average, bitwise_and, bitwise_or, bitwise_xor, generic |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.all_slice (mesh::AllSliceOp)
All-slice over a device mesh. This is the inverse of all-gather.
Syntax:
operation ::= `mesh.all_slice` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
attr-dict `:` type($input) `->` type($result)
Slices along the slice_axis tensor axis.
This operation can be thought of as the inverse of all-gather.
Technically, it is not required that all processes have the same input tensor.
Each process will slice a piece of its local tensor based on its in-group device index.
The operation does not communicate data between devices.
Example:
mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
: tensor<2x4xi8> -> tensor<2x2xi8>
Input:
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
Result:
slice tensor
axis 1
------------>
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
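Because all-slice moves no data, each device's result depends only on its own tensor and its in-group index. A minimal sketch, assuming the sliced extent divides evenly by the group size; all_slice here is a hypothetical helper operating on nested lists.

```python
def all_slice(local, in_group_index, group_size, slice_axis):
    """Return the piece of a local 2-D tensor kept by one device.

    The device splits its own tensor into group_size equal pieces along
    slice_axis and keeps piece number in_group_index.
    """
    extent = len(local) if slice_axis == 0 else len(local[0])
    assert extent % group_size == 0, "sketch assumes an even split"
    size = extent // group_size
    lo, hi = in_group_index * size, (in_group_index + 1) * size
    if slice_axis == 0:
        return [row[:] for row in local[lo:hi]]
    return [row[lo:hi] for row in local]


row0 = [[1, 2, 5, 6], [3, 4, 7, 8]]  # held by devices (0, 0) and (0, 1)
print(all_slice(row0, 0, 2, slice_axis=1))  # [[1, 2], [3, 4]] on device (0, 0)
print(all_slice(row0, 1, 2, slice_axis=1))  # [[5, 6], [7, 8]] on device (0, 1)
```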
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
slice_axis | ::mlir::IntegerAttr | index attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results:
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
mesh.all_to_all (mesh::AllToAllOp)
All-to-all over a device mesh.
Syntax:
operation ::= `mesh.all_to_all` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`split_axis` `=` $split_axis
`concat_axis` `=` $concat_axis
attr-dict `:` type($input) `->` type($result)
Performs an all-to-all on tensor pieces split along split_axis.
The resulting pieces are concatenated along concat_axis on each device.
Example:
mesh.mesh @mesh0(shape = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
: tensor<3x2xi8> -> tensor<3x2xi8>
Input:
device device device
(0) (1) (2)
+-------+-------+-------+ | split and concat along
| 11 12 | 21 22 | 31 32 | | tensor axis 0
| 13 14 | 23 24 | 33 34 | ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+
Result:
device device device
(0) (1) (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
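The data movement in the example can be reproduced with a short Python model of one device group where split_axis and concat_axis are both 0. The helpers split_rows and all_to_all are hypothetical, and an even split is assumed.

```python
def split_rows(t, n):
    """Split a 2-D nested list into n equal row blocks (assumes even split)."""
    size = len(t) // n
    return [t[i * size:(i + 1) * size] for i in range(n)]


def all_to_all(group_inputs):
    """All-to-all within one ordered group, split_axis = concat_axis = 0.

    Device i sends its j-th row block to device j; device j concatenates the
    blocks it receives in device order.
    """
    n = len(group_inputs)
    pieces = [split_rows(t, n) for t in group_inputs]  # pieces[i][j]: i -> j
    return [[row for i in range(n) for row in pieces[i][j]] for j in range(n)]


inputs = [
    [[11, 12], [13, 14], [15, 16]],  # device (0)
    [[21, 22], [23, 24], [25, 26]],  # device (1)
    [[31, 32], [33, 34], [35, 36]],  # device (2)
]
result = all_to_all(inputs)
print(result[0])  # [[11, 12], [21, 22], [31, 32]]
```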
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
split_axis | ::mlir::IntegerAttr | index attribute |
concat_axis | ::mlir::IntegerAttr | index attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results:
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
mesh.broadcast (mesh::BroadcastOp)
Broadcast over a device mesh.
Syntax:
operation ::= `mesh.broadcast` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Broadcast the tensor on root to all devices in each respective group.
The operation broadcasts along mesh axes mesh_axes.
The root device specifies the in-group multi-index that is broadcast to all other devices in the group.
Example:
mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.broadcast %0 on @mesh0
mesh_axes = [0]
root = [0]
: (tensor<2xi8>) -> tensor<2xi8>
Input:
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
device (1, 0) -> | | | <- device (1, 1)
+-------+-------+
Output:
+-------+-------+
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1)
+-------+-------+
device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1)
+-------+-------+
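For a 2-D mesh with mesh_axes = [0], each mesh column is one device group, so the example can be modeled as below. broadcast_axis0 is a hypothetical helper; non-root inputs are represented as None because their contents do not affect the result.

```python
def broadcast_axis0(inputs, mesh_shape, root):
    """Broadcast along mesh axis 0 of a 2-D mesh (mesh_axes = [0]).

    Column j forms one device group ordered by row; device (root, j) is the
    group's root and every device (i, j) receives a copy of its tensor.
    """
    rows, cols = mesh_shape
    return {(i, j): list(inputs[(root, j)]) for i in range(rows) for j in range(cols)}


inputs = {(0, 0): [1, 2], (0, 1): [3, 4], (1, 0): None, (1, 1): None}
out = broadcast_axis0(inputs, (2, 2), root=0)
print(out[(1, 0)], out[(1, 1)])  # [1, 2] [3, 4]
```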
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
root_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.gather (mesh::GatherOp)
Gather over a device mesh.
Syntax:
operation ::= `mesh.gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`gather_axis` `=` $gather_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Gathers on device root along the gather_axis tensor axis.
root specifies the coordinates of a device along mesh_axes.
It uniquely identifies the root device for each device group.
The result tensor on non-root devices is undefined.
Using it will result in undefined behavior.
Example:
mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
gather_axis = 1 root = [1]
: (tensor<2x2xi8>) -> tensor<2x4xi8>
Input:
gather tensor
axis 1
------------>
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
Result:
+-------------+
| 1 2 5 6 | <- device (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- device (1, 1)
| 11 12 15 16 |
+-------------+
Devices (0, 0) and (1, 0) have undefined results.
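The example can be modeled in Python for a 2-D mesh with mesh_axes = [1] and gather_axis = 1, where each mesh row is one device group. gather_axis1 is a hypothetical helper; None stands in for the undefined result on non-root devices.

```python
def gather_axis1(inputs, mesh_shape, root):
    """Gather along mesh axis 1 of a 2-D mesh (mesh_axes = [1], gather_axis = 1).

    Row i forms one device group ordered by column. The root device (i, root)
    receives the column-wise concatenation; other devices get None, which
    stands in for the undefined result.
    """
    rows, cols = mesh_shape
    out = {}
    for i in range(rows):
        members = [inputs[(i, j)] for j in range(cols)]
        gathered = [sum((t[r] for t in members), []) for r in range(len(members[0]))]
        for j in range(cols):
            out[(i, j)] = gathered if j == root else None
    return out


inputs = {
    (0, 0): [[1, 2], [3, 4]], (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]], (1, 1): [[13, 14], [15, 16]],
}
out = gather_axis1(inputs, (2, 2), root=1)
print(out[(0, 1)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
print(out[(0, 0)])  # None
```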
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
gather_axis | ::mlir::IntegerAttr | index attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
root_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | non-0-ranked.tensor of any type values |
mesh.mesh (mesh::MeshOp)
Description of a device/process mesh.
Syntax:
operation ::= `mesh.mesh` $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
attr-dict
The mesh.mesh operation is a symbol operation that identifies a specific mesh. The operation has two attributes:
sym_name: This attribute uniquely identifies the name of the mesh. This name serves as a symbolic reference to the mesh throughout the MLIR module, allowing for consistent referencing and easier debugging.
shape: This attribute represents the shape of the device mesh. It uses the same notation as a tensor shape and also allows dynamic dimensions. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be determined during compile time. For example, 2x?x4.
Example:
// A device mesh with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
mesh.mesh @mesh0(shape = 4x8x12)
// A device mesh with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.mesh @mesh1(shape = 4x?)
// A device mesh with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.mesh @mesh2(shape = ?x4)
// A device mesh with 2 axes, the number of devices along both axes
// is unknown
mesh.mesh @mesh3(shape = ?x?)
Interfaces: Symbol
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
sym_name | ::mlir::StringAttr | string attribute |
shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
mesh.mesh_shape (mesh::MeshShapeOp)
Get the shape of the mesh.
Syntax:
operation ::= `mesh.mesh_shape` $mesh (`axes` `=` $axes^)?
attr-dict `:` type($result)
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
Results:
Result | Description |
---|---|
result | variadic of index |
mesh.process_linear_index (mesh::ProcessLinearIndexOp)
Get the linear index of the current device.
Syntax:
operation ::= `mesh.process_linear_index` `on` $mesh attr-dict `:` type($result)
Example:
%idx = mesh.process_linear_index on @mesh : index
If @mesh has shape (10, 20, 30), a device with multi-index (1, 2, 3) will have linear index 3 + 30*2 + 20*30*1.
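The formula in the example is a row-major linearization where the last mesh axis varies fastest. A minimal sketch with a hypothetical helper:

```python
def linear_index(multi_index, shape):
    """Row-major linearization: the last mesh axis varies fastest."""
    idx = 0
    for coord, size in zip(multi_index, shape):
        assert 0 <= coord < size, "coordinate out of range"
        idx = idx * size + coord
    return idx


print(linear_index((1, 2, 3), (10, 20, 30)))  # 663, i.e. 3 + 30*2 + 20*30*1
```

Accumulating with idx * size + coord (Horner's scheme) yields the same sum of coordinate times stride as the expanded formula, without computing the strides separately.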
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
Results:
Result | Description |
---|---|
result | index |
mesh.process_multi_index (mesh::ProcessMultiIndexOp)
Get the multi-index of the current device along the specified mesh axes.
Syntax:
operation ::= `mesh.process_multi_index` `on` $mesh (`axes` `=` $axes^)?
attr-dict `:` type($result)
It is used in the SPMD format of IR.
The axes must be non-negative and less than the total number of mesh axes.
If the axes are empty, then the index along all axes is returned.
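Assuming the same row-major device numbering as in mesh.process_linear_index, the multi-index can be recovered from a linear index and then restricted to the requested axes. multi_index is a hypothetical helper for illustration, not how the op is lowered.

```python
def multi_index(linear, shape, axes=()):
    """Recover a device's multi-index from its row-major linear index.

    An empty axes tuple returns the index along all mesh axes; otherwise only
    the requested axes are returned, in the order given.
    """
    coords = []
    for size in reversed(shape):
        linear, coord = divmod(linear, size)
        coords.append(coord)
    coords.reverse()
    if not axes:
        return tuple(coords)
    assert all(0 <= a < len(shape) for a in axes), "axes must be valid mesh axes"
    return tuple(coords[a] for a in axes)


print(multi_index(663, (10, 20, 30)))          # (1, 2, 3)
print(multi_index(663, (10, 20, 30), (2, 0)))  # (3, 1)
```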
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
Results:
Result | Description |
---|---|
result | variadic of index |
mesh.recv (mesh::RecvOp)
Receive over a device mesh.
Syntax:
operation ::= `mesh.recv` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
attr-dict `:` functional-type(operands, results)
Receive from a device within a device group.
Interfaces: OpAsmOpInterface, SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
source | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
source_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.reduce (mesh::ReduceOp)
Reduce over a device mesh.
Syntax:
operation ::= `mesh.reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`reduction` `=` $reduction^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
Reduces on device root within each device group.
root specifies the coordinates of a device along mesh_axes.
It uniquely identifies the root device within its device group.
The accumulation element type is specified by the result type and it does not need to match the input element type.
The input element is converted to the result element type before performing the reduction.
Attributes:
reduction: Indicates the reduction method.
Example:
%1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
reduction = <max> root = [2, 3]
: (tensor<3x4xf32>) -> tensor<3x4xf64>
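A sketch of reduce-to-root for one device group follows. reduce_to_root is hypothetical; root is taken to be the in-group index of the root device, and because the text above does not describe the result on non-root devices, the sketch simply returns None for them.

```python
from functools import reduce as fold


def reduce_to_root(group_inputs, root, combine=max, to_result_type=float):
    """Element-wise reduction of one group's 1-D values onto its root device.

    Inputs are converted to the result element type before combining.
    Non-root devices get None in this sketch (their result is not modeled).
    """
    converted = [[to_result_type(x) for x in t] for t in group_inputs]
    reduced = [fold(combine, column) for column in zip(*converted)]
    return [reduced if k == root else None for k in range(len(group_inputs))]


print(reduce_to_root([[1, 7], [5, 2], [3, 3]], root=2))
# [None, None, [5.0, 7.0]]
```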
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. Enum cases: sum, max, min, product, average, bitwise_and, bitwise_or, bitwise_xor, generic |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | ranked tensor of any type values |
root_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.reduce_scatter (mesh::ReduceScatterOp)
Reduce-scatter over a device mesh.
Syntax:
operation ::= `mesh.reduce_scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`reduction` `=` $reduction^)?
`scatter_axis` `=` $scatter_axis
attr-dict `:` type($input) `->` type($result)
After the reduction, the result is scattered within each device group.
The tensor is split along scatter_axis and the pieces are distributed across the device group.
Example:
mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
reduction = <sum> scatter_axis = 0
: tensor<2x2xf32> -> tensor<1x2xf64>
Input:
device
(0, 1)
↓
+-------+-------+ | scatter tensor
device (0, 0) -> | 1 2 | 5 6 | | axis 0
| 3 4 | 7 8 | ↓
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 |
| 11 12 | 15 16 |
+-------+-------+
↑
device
(1, 1)
Result:
+-------+
| 6 8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
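The values in the Result diagram are element-wise sums of each group's inputs, split along tensor axis 0. A minimal Python model of one group, using a hypothetical helper reduce_scatter_sum and assuming the row count divides evenly by the group size:

```python
def reduce_scatter_sum(group_inputs):
    """Sum one group's 2-D tensors element-wise, then split along tensor axis 0.

    group_inputs is ordered by in-group index; device k keeps row block k.
    Assumes the row count divides evenly by the group size.
    """
    n = len(group_inputs)
    summed = [[sum(vals) for vals in zip(*rows)] for rows in zip(*group_inputs)]
    size = len(summed) // n
    return [summed[k * size:(k + 1) * size] for k in range(n)]


row0 = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]          # devices (0, 0), (0, 1)
row1 = [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]   # devices (1, 0), (1, 1)
print(reduce_scatter_sum(row0))  # [[[6, 8]], [[10, 12]]]
print(reduce_scatter_sum(row1))  # [[[22, 24]], [[26, 28]]]
```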
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultRank
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
reduction | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. Enum cases: sum, max, min, product, average, bitwise_and, bitwise_or, bitwise_xor, generic |
scatter_axis | ::mlir::IntegerAttr | index attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.scatter (mesh::ScatterOp)
Scatter over a device mesh.
Syntax:
operation ::= `mesh.scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`scatter_axis` `=` $scatter_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
For each device group, split the input tensor on the root device along axis scatter_axis and scatter the parts across the group devices.
Example:
mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
scatter_axis = 0
root = [1]
: (tensor<2x2xi8>) -> tensor<1x2xi8>
Input:
device
(0, 1)
↓
+-------+-------+ | scatter tensor
device (0, 0) -> | | | | axis 0
| | | ↓
+-------+-------+
device (1, 0) -> | 1 2 | 5 6 |
| 3 4 | 7 8 |
+-------+-------+
↑
device
(1, 1)
Result:
device
(0, 1)
↓
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 |
+-------+-------+
device (1, 0) -> | 3 4 | 7 8 |
+-------+-------+
↑
device
(1, 1)
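In the example, each mesh column is one group whose root is in-group index 1, that is, device (1, j). Splitting the root's tensor along tensor axis 0 can be sketched with a hypothetical helper that assumes an even split:

```python
def scatter_rows(root_tensor, group_size):
    """Split the root's 2-D tensor into equal row blocks; device k gets block k.

    Models scatter_axis = 0 and assumes an even split.
    """
    assert len(root_tensor) % group_size == 0, "uneven split not modeled"
    size = len(root_tensor) // group_size
    return [root_tensor[k * size:(k + 1) * size] for k in range(group_size)]


# Group for mesh column 0: root (1, 0) holds [[1, 2], [3, 4]].
print(scatter_rows([[1, 2], [3, 4]], 2))  # [[[1, 2]], [[3, 4]]] for (0, 0), (1, 0)
# Group for mesh column 1: root (1, 1) holds [[5, 6], [7, 8]].
print(scatter_rows([[5, 6], [7, 8]], 2))  # [[[5, 6]], [[7, 8]]] for (0, 1), (1, 1)
```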
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
scatter_axis | ::mlir::IntegerAttr | index attribute |
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
root_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.send (mesh::SendOp)
Send over a device mesh.
Syntax:
operation ::= `mesh.send` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
attr-dict `:` functional-type(operands, results)
Send from one device to another within a device group.
Interfaces: OpAsmOpInterface, SymbolUserOpInterface
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
destination | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
destination_dynamic | variadic of index |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.shard (mesh::ShardOp)
Annotate how a tensor is sharded across a mesh.
Syntax:
operation ::= `mesh.shard` $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
type($result)
The mesh.shard operation is designed to specify and guide the sharding behavior of a tensor value across a mesh topology. This operation has one operand and two attributes:
src: This operand represents the tensor value that needs to be annotated for sharding.
shard: This attribute is of type MeshSharding, which is the core data structure to represent distribution of a tensor on a mesh.
annotate_for_users: A unit attribute addressing the scenario when a tensor’s sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.
Example:
func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
...
}
func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
...
}
// The first mesh.shard op applies to %arg0, the second mesh.shard op
// applies for the operand of op0, the third mesh.shard op applies for the
// operand of op1
func.func @both_result_and_multi_operands_annotated(
%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
%1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
%2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
"op0"(%1) : ...
"op1"(%2) : ...
...
}
The following usages are undefined:
func.func @annotate_on_same_result_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
%1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
...
}
func.func @annotate_on_same_result_same_value_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
%1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32>
...
}
func.func @annotate_on_same_operand_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
%1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
...
}
func.func @result_annotated_after_operand(
%arg0 : tensor<4x8xf32>) -> () {
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
%1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
...
}
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
shard | ::mlir::mesh::MeshShardingAttr | Attribute that extends tensor type to distributed tensor type. |
annotate_for_users | ::mlir::UnitAttr | unit attribute |
Operands:
Operand | Description |
---|---|
src | ranked tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
mesh.shift (mesh::ShiftOp)
Shift over a device mesh.
Syntax:
operation ::= `mesh.shift` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`shift_axis` `=` $shift_axis
`offset` `=` $offset
(`rotate` $rotate^)?
attr-dict `:` type($input) `->` type($result)
Within each device group, shift along mesh axis shift_axis by an offset offset.
The result on devices that do not have a corresponding source is undefined.
shift_axis must be one of mesh_axes.
If the rotate attribute is present, a rotation is done instead of a shift.
Example:
mesh.mesh @mesh0(shape = 2x4)
%1 = mesh.shift %0 on @mesh0 mesh_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
Input:
mesh axis 1
----------->
+----+----+----+----+
| 1 | 2 | 3 | 4 |
+----+----+----+----+
| 5 | 6 | 7 | 8 |
+----+----+----+----+
Result:
+----+----+----+----+
| 3 | 4 | 1 | 2 |
+----+----+----+----+
| 7 | 8 | 5 | 6 |
+----+----+----+----+
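A Python model of shift along one mesh axis of a single group. The text above does not state the direction of the shift, so this sketch assumes the value at position p moves to position p + offset; with rotate, positions wrap around the group. For the example (offset 2 on a group of 4 with rotate) both directions give the same result. shift is a hypothetical helper, and None stands in for the undefined result.

```python
def shift(group_values, offset, rotate=False):
    """Shift values along one mesh axis within a group ordered by position.

    Assumed convention: the value at position p moves to position p + offset.
    Without rotate, positions with no source get None (undefined result).
    """
    n = len(group_values)
    out = [None] * n
    for p, v in enumerate(group_values):
        q = p + offset
        if rotate:
            out[q % n] = v  # Python's % also wraps negative offsets
        elif 0 <= q < n:
            out[q] = v
    return out


print(shift([1, 2, 3, 4], 2, rotate=True))  # [3, 4, 1, 2]
print(shift([5, 6, 7, 8], 2, rotate=True))  # [7, 8, 5, 6]
```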
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultShape
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
shift_axis | ::mlir::IntegerAttr | index attribute |
offset | ::mlir::IntegerAttr | 64-bit signless integer attribute |
rotate | ::mlir::UnitAttr | unit attribute |
Operands:
Operand | Description |
---|---|
input | non-0-ranked.tensor of any type values |
Results:
Result | Description |
---|---|
result | ranked tensor of any type values |
Attributes
MeshShardingAttr
Attribute that extends tensor type to distributed tensor type.
Syntax:
#mesh.shard<
::mlir::FlatSymbolRefAttr, # mesh
::llvm::ArrayRef<MeshAxesAttr>, # split_axes
::llvm::ArrayRef<MeshAxis>, # partial_axes
::mlir::mesh::ReductionKind # partial_type
>
The MeshSharding attribute is used in a mesh.shard operation.
It specifies how a tensor is sharded and distributed across the process mesh.
mesh: this attribute is a FlatSymbolRefAttr that refers to the device mesh where the distributed tensor is placed. The symbol must resolve to a mesh.mesh operation.
split_axes: is an array composed of int64_t sub-arrays. The outer array’s maximum size is the rank of the related tensor. For the i-th sub-array, if its value is [x, y], it indicates that the tensor’s i-th dimension is split along the x and y axes of the device mesh.
partial_axes: if not empty, this signifies that the tensor is partial along the specified mesh axes. An all-reduce should be applied to obtain the complete tensor, with the reduction type specified by partial_type.
partial_type: indicates the reduction type of the possible all-reduce op. Its value is a ReductionKind; generic is not an allowed value inside a shard attribute.
Example:
mesh.mesh @mesh0(shape = 2x2x4)
// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in split_axes, even
// if it's empty. Otherwise, a parsing error will occur.
#mesh.shard<@mesh0, [[]]>
// The tensor is sharded on the first dimension along axis 0 of @mesh0
#mesh.shard<@mesh0, [[0]]>
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
#mesh.shard<@mesh0, [[0], []], partial = sum[1]>
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
#mesh.shard<@mesh0, [[0]], partial = max[1]>
// Could be used in the attribute of mesh.shard op
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
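The effect of split_axes on the per-device (local) tensor shape can be sketched as follows. This is an assumption-laden illustration, not the dialect's implementation: local_shard_shape is hypothetical, it divides tensor dimension i by the product of the sizes of the mesh axes listed in split_axes[i], treats missing trailing sub-arrays as unsplit dimensions, assumes static sizes that divide evenly, and ignores partial_axes (partial tensors keep their full shape).

```python
from math import prod


def local_shard_shape(tensor_shape, mesh_shape, split_axes):
    """Per-device shape of a tensor under a split_axes sharding (sketch).

    Dimension i is divided by the product of the sizes of the mesh axes in
    split_axes[i]; dimensions without a sub-array are not split.
    """
    shape = []
    for i, dim in enumerate(tensor_shape):
        axes = split_axes[i] if i < len(split_axes) else []
        parts = prod(mesh_shape[a] for a in axes)  # prod of nothing is 1
        assert dim % parts == 0, "sketch assumes an even split"
        shape.append(dim // parts)
    return tuple(shape)


mesh0 = (2, 2, 4)  # matches mesh.mesh @mesh0(shape = 2x2x4)
print(local_shard_shape((4, 8), mesh0, [[]]))        # (4, 8): replicated
print(local_shard_shape((4, 8), mesh0, [[0]]))       # (2, 8)
print(local_shard_shape((4, 8), mesh0, [[0], [2]]))  # (2, 2)
```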
Parameters:
Parameter | C++ type | Description |
---|---|---|
mesh | ::mlir::FlatSymbolRefAttr | The mesh on which tensors are sharded. |
split_axes | ::llvm::ArrayRef<MeshAxesAttr> | |
partial_axes | ::llvm::ArrayRef<MeshAxis> | |
partial_type | ::mlir::mesh::ReductionKind | |
ReductionKindAttr
Reduction of an iterator/mesh dimension.
Syntax:
#mesh.partial<
::mlir::mesh::ReductionKind # value
>
Enum cases:
- sum (Sum)
- max (Max)
- min (Min)
- product (Product)
- average (Average)
- bitwise_and (BitwiseAnd)
- bitwise_or (BitwiseOr)
- bitwise_xor (BitwiseXor)
- generic (Generic)
Parameters:
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::mesh::ReductionKind | an enum of type ReductionKind |