'mesh' Dialect

The mesh dialect contains a set of attributes, operations and interfaces that are useful for representing sharding and communication on a device mesh cluster.

Collective Communication Operations 

There are a number of operations in the Mesh dialect to facilitate communication between devices in a mesh. It is assumed that the user is familiar with collective operations. Wikipedia has a good explanation. The main addition is that the collectives in this dialect have mesh semantics.

The operation attributes mesh and mesh_axes specify a list of device mesh axes that partition the devices into disjoint groups. The collective operation is performed between the devices in the same group. Devices that have the same coordinates on all axes outside of mesh_axes are in the same group. For example, if we have a device mesh of size 2x3x4x5 and the partition mesh axes list is [0, 1], then the devices are partitioned into the groups { { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }. Devices (1, 0, 2, 3) and (1, 1, 2, 3) are in the same group, while device (1, 0, 2, 4) is in another group.

Some collective operations, like all-to-all and all-gather, care about the order of devices. The order of the devices in a device group is induced by the order of the axes in mesh_axes, ordered from outer to inner. For example, with the axis list [3, 1], device (i, 1, k, 0) will precede both devices (i, 0, k, 1) and (i, 2, k, 0).
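The grouping and ordering rules above can be sketched in plain Python. This is an illustrative aid only, not part of the dialect; `device_groups` is a hypothetical helper.

```python
from itertools import product

def device_groups(mesh_shape, mesh_axes):
    """Partition all device coordinates into groups: devices that agree on
    every coordinate outside mesh_axes fall into the same group."""
    other_axes = [a for a in range(len(mesh_shape)) if a not in mesh_axes]
    groups = {}
    for coord in product(*(range(s) for s in mesh_shape)):
        key = tuple(coord[a] for a in other_axes)
        groups.setdefault(key, []).append(coord)
    # Order within a group follows mesh_axes from outer to inner.
    for key in groups:
        groups[key].sort(key=lambda c: tuple(c[a] for a in mesh_axes))
    return groups

# 2x3x4x5 mesh, mesh_axes = [0, 1]: (1, 0, 2, 3) and (1, 1, 2, 3) share the
# group keyed by their (k, m) coordinates; (1, 0, 2, 4) lands elsewhere.
groups = device_groups([2, 3, 4, 5], mesh_axes=[0, 1])
```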

Operations 


mesh.all_gather (mesh::AllGatherOp) 

All-gather over a device mesh.

Syntax:

operation ::= `mesh.all_gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
              attr-dict `:` type($input) `->` type($result)

Gathers along the gather_axis tensor axis.

Example:

mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
  : tensor<2x2xi8> -> tensor<2x4xi8>

Input:

                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+

Result:

gather tensor
axis 1
------------>
+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
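The example above can be mimicked on plain nested lists. This is an illustrative sketch only, not dialect API; `all_gather` here is a hypothetical helper.

```python
# Hypothetical helper mimicking all-gather semantics on 2-D tensors
# represented as nested lists: each device in a group receives the
# concatenation of all group members' shards along gather_axis.
def all_gather(shards, gather_axis):
    if gather_axis == 0:
        # Stack the shards' rows.
        return [row for shard in shards for row in shard]
    # gather_axis == 1: concatenate the shards row by row.
    return [sum((shard[i] for shard in shards), [])
            for i in range(len(shards[0]))]

# Devices (0, 0) and (0, 1) from the example, gathered along tensor axis 1:
lhs = [[1, 2], [3, 4]]
rhs = [[5, 6], [7, 8]]
print(all_gather([lhs, rhs], gather_axis=1))  # [[1, 2, 5, 6], [3, 4, 7, 8]]
```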

Traits: SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: SymbolUserOpInterface

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
gather_axis | ::mlir::IntegerAttr | index attribute

Operands: 

Operand | Description
input | non-0-ranked tensor of any type values

Results: 

Result | Description
result | non-0-ranked tensor of any type values

mesh.all_reduce (mesh::AllReduceOp) 

All-reduce over a device mesh.

Syntax:

operation ::= `mesh.all_reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
              attr-dict `:` type($input) `->` type($result)

The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.

Attributes: reduction: Indicates the reduction method.

Example:

%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
  : tensor<3x4xf32> -> tensor<3x4xf64>
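The convert-then-reduce behavior can be sketched on nested lists. This is illustrative only; `all_reduce` is a hypothetical helper, not dialect API.

```python
from functools import reduce

# Hypothetical helper mimicking all-reduce on 2-D tensors (nested lists):
# every element is first converted to the accumulation type, then reduced
# elementwise across the group; every device receives the same result.
def all_reduce(shards, op, convert):
    rows, cols = len(shards[0]), len(shards[0][0])
    converted = [[[convert(x) for x in row] for row in s] for s in shards]
    return [[reduce(op, (s[i][j] for s in converted)) for j in range(cols)]
            for i in range(rows)]

# f32 inputs accumulated as f64, reduction = max (cf. the example above):
a = [[1, 8], [3, 4]]
b = [[5, 2], [7, 6]]
print(all_reduce([a, b], op=max, convert=float))  # [[5.0, 8.0], [7.0, 6.0]]
```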

Traits: SameOperandsAndResultShape

Interfaces: SymbolUserOpInterface

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
reduction | ::mlir::mesh::PartialAttr | partial type of a distributed tensor

Enum cases:

  • sum (Sum)
  • max (Max)
  • min (Min)
  • generic (Generic)

Operands: 

Operand | Description
input | ranked tensor of any type values

Results: 

Result | Description
result | ranked tensor of any type values

mesh.all_to_all (mesh::AllToAllOp) 

All-to-all over a device mesh.

Syntax:

operation ::= `mesh.all_to_all` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `split_axis` `=` $split_axis
              `concat_axis` `=` $concat_axis
              attr-dict `:` type($input) `->` type($result)

Performs an all-to-all on tensor pieces split along split_axis. The resulting pieces are concatenated along concat_axis on each device.

Example:

mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
  split_axis = 0 concat_axis = 0
  : tensor<3x2xi8> -> tensor<3x2xi8>

Input:

 device  device  device
 (0)     (1)     (2)
+-------+-------+-------+  | split and concat along
| 11 12 | 21 22 | 31 32 |  | tensor axis 0
| 13 14 | 23 24 | 33 34 |  ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+

Result:

 device  device  device
 (0)     (1)     (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
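The split/exchange/concatenate steps above can be sketched on nested lists. This is illustrative only; `all_to_all` is a hypothetical helper, not dialect API.

```python
# Hypothetical helper mimicking all-to-all on 2-D tensors (nested lists).
# Each device's tensor is split into len(shards) pieces along split_axis;
# device d receives piece d from every peer and concatenates the pieces
# along concat_axis.
def all_to_all(shards, split_axis, concat_axis):
    n = len(shards)

    def split(t, axis):
        if axis == 0:
            step = len(t) // n
            return [t[i * step:(i + 1) * step] for i in range(n)]
        step = len(t[0]) // n
        return [[row[i * step:(i + 1) * step] for row in t] for i in range(n)]

    def concat(parts, axis):
        if axis == 0:
            return [row for p in parts for row in p]
        return [sum((p[i] for p in parts), []) for i in range(len(parts[0]))]

    pieces = [split(t, split_axis) for t in shards]
    return [concat([pieces[src][dst] for src in range(n)], concat_axis)
            for dst in range(n)]

# The three devices from the example above:
d0 = [[11, 12], [13, 14], [15, 16]]
d1 = [[21, 22], [23, 24], [25, 26]]
d2 = [[31, 32], [33, 34], [35, 36]]
result = all_to_all([d0, d1, d2], split_axis=0, concat_axis=0)
print(result[0])  # [[11, 12], [21, 22], [31, 32]]
```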

Traits: SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: SymbolUserOpInterface

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
split_axis | ::mlir::IntegerAttr | index attribute
concat_axis | ::mlir::IntegerAttr | index attribute

Operands: 

Operand | Description
input | non-0-ranked tensor of any type values

Results: 

Result | Description
result | non-0-ranked tensor of any type values

mesh.cluster (mesh::ClusterOp) 

Representing a mesh cluster

Syntax:

operation ::= `mesh.cluster` $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
              attr-dict

The mesh.cluster operation is a symbol operation that identifies a specific mesh cluster. The operation has three attributes:

  1. sym_name: This attribute uniquely identifies the name of the mesh cluster. This name serves as a symbolic reference to the cluster throughout the MLIR module, allowing for consistent referencing and easier debugging.

  2. rank: This attribute specifies the number of axes of the cluster. The rank indicates the dimensionality of the mesh cluster and can be used to determine the layout and the addressing space of the computation distributed across the mesh.

  3. dim_sizes: This attribute represents the device assignment along the axes of the cluster. Each integer in the array corresponds to the number of devices along a specific axis. If an integer value is 0, it implies that the number of devices along that axis is unknown. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be determined during compile time.

Example:

// A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12 
mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])

// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.cluster @mesh1(rank = 2, dim_sizes = [4])

// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])

// A device mesh cluster with 2 axes, the number of devices along both axes
// is unknown
mesh.cluster @mesh3(rank = 2)

// Used in the mesh sharding attribute to extend the standard tensor to
// distributed
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
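The dim_sizes conventions above (0 means unknown, and omitted trailing entries also mean unknown) can be sketched as a small normalization helper. This is illustrative only; `mesh_shape` is a hypothetical function, not dialect API.

```python
def mesh_shape(rank, dim_sizes=()):
    """Normalize a mesh.cluster's dim_sizes: missing trailing entries and
    zeros both denote an unknown axis size (None)."""
    sizes = list(dim_sizes) + [0] * (rank - len(dim_sizes))
    return [s if s != 0 else None for s in sizes]

print(mesh_shape(3, [4, 8, 12]))  # [4, 8, 12]   (@mesh0)
print(mesh_shape(2, [4]))         # [4, None]    (@mesh1)
print(mesh_shape(2, [0, 4]))      # [None, 4]    (@mesh2)
print(mesh_shape(2))              # [None, None] (@mesh3)
```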

Interfaces: Symbol

Attributes: 

Attribute | MLIR Type | Description
sym_name | ::mlir::StringAttr | string attribute
rank | ::mlir::IntegerAttr | 64-bit signless integer attribute
dim_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

mesh.reduce_scatter (mesh::ReduceScatterOp) 

Reduce-scatter over a device mesh.

Syntax:

operation ::= `mesh.reduce_scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              (`reduction` `=` $reduction^)?
              `scatter_axis` `=` $scatter_axis
              attr-dict `:` type($input) `->` type($result)

After the reduction, the result is scattered within each device group: the tensor is split along scatter_axis and the pieces are distributed across the device group. Example:

mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
  reduction = <sum> scatter_axis = 0
  : tensor<2x2xf32> -> tensor<1x2xf64>

Input:

                          device
                          (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                 |  3  4 |  7  8 |  ↓
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 |
                 | 11 12 | 15 16 |
                 +-------+-------+
                            ↑
                          device
                          (1, 1)

Result:

+-------+
|  6  8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
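The result values in the diagram can be reproduced with a small sketch that reduces elementwise and then scatters along tensor axis 0. This is illustrative only; `reduce_scatter` is a hypothetical helper, not dialect API.

```python
from functools import reduce

# Hypothetical helper mimicking reduce-scatter on 2-D tensors (nested
# lists): reduce elementwise across the group, then split the result
# along tensor axis 0 and hand piece d to the d-th device in the group.
def reduce_scatter(shards, op, convert):
    n = len(shards)
    rows, cols = len(shards[0]), len(shards[0][0])
    reduced = [[reduce(op, (convert(s[i][j]) for s in shards))
                for j in range(cols)] for i in range(rows)]
    step = rows // n
    return [reduced[d * step:(d + 1) * step] for d in range(n)]

# Devices (0, 0) and (0, 1) from the input diagram, reduced with sum:
a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(reduce_scatter([a, b], op=lambda x, y: x + y, convert=float))
# [[[6.0, 8.0]], [[10.0, 12.0]]]
```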

Traits: SameOperandsAndResultRank

Interfaces: SymbolUserOpInterface

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
reduction | ::mlir::mesh::PartialAttr | partial type of a distributed tensor
scatter_axis | ::mlir::IntegerAttr | index attribute

Enum cases (reduction):

  • sum (Sum)
  • max (Max)
  • min (Min)
  • generic (Generic)

Operands: 

Operand | Description
input | non-0-ranked tensor of any type values

Results: 

Result | Description
result | ranked tensor of any type values

mesh.shard (mesh::ShardOp) 

Annotate on how a tensor is sharded across a mesh cluster.

Syntax:

operation ::= `mesh.shard` $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
              type($result)

The mesh.shard operation is designed to specify and guide the sharding behavior of a tensor value across a mesh topology. This operation has one operand and two attributes:

  1. input: This operand represents the tensor value that needs to be annotated for sharding.

  2. shard: This attribute is of type MeshShardingAttr, the core data structure for representing a distributed tensor on a mesh cluster.

  3. annotate_for_users: A unit attribute addressing the scenario when a tensor’s sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.

Example:

func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
  ...
}

func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
  ...
}

// The first mesh.shard op applies to %arg0, the second mesh.shard op
// applies to the operand of op0, and the third mesh.shard op applies to
// the operand of op1
func.func @both_result_and_multi_operands_annotated(
    %arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
  %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
  %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
  "op0"(%1) : ...
  "op1"(%2) : ...
  ...
}

The following usages are undefined:

func.func @annotate_on_same_result_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
  %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_result_same_value_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
  %1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_operand_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
  %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
  ...
}

func.func @result_annotated_after_operand(
    %arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
  %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
  ...
}

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
shard | ::mlir::mesh::MeshShardingAttr | attribute that extends tensor type to distributed tensor type (see MeshShardingAttr below)
annotate_for_users | ::mlir::UnitAttr | unit attribute

Operands: 

Operand | Description
src | Multi-dimensional array with a fixed number of dimensions

Results: 

Result | Description
result | Multi-dimensional array with a fixed number of dimensions

Attributes 

MeshShardingAttr 

Attribute that extends tensor type to distributed tensor type.

Syntax:

#mesh.shard<
  ::mlir::SymbolRefAttr,   # cluster
  ::llvm::ArrayRef<::mlir::DenseI32ArrayAttr>,   # split_axes
  ::llvm::ArrayRef<int32_t>,   # partial_axes
  ::mlir::mesh::Partial   # partial_type
>

The MeshSharding attribute can be used in the encoding of a RankedTensorType or in the mesh.shard op. It contains four sub-attributes:

  1. cluster: this attribute is a SymbolRefAttr that refers to the mesh cluster where the distributed tensor is placed. The symbol must resolve to a mesh.cluster operation.

  2. split_axes: an array of integer sub-arrays. The outer array's maximum size is the rank of the related tensor. If the i-th sub-array is [x, y], the tensor's i-th dimension is split along the x and y axes of the device mesh.

  3. partial_axes: if not empty, this signifies that the tensor is a partial tensor along the specified mesh axes. An all-reduce should be applied to obtain the complete tensor, with the reduction type specified by partial_type.

  4. partial_type: indicates the reduction type of the possible all-reduce op. It has 4 possible values:

  • partial_sum: denotes it’s an all-reduce-sum
  • partial_max: denotes it’s an all-reduce-max
  • partial_min: denotes it’s an all-reduce-min
  • partial_generic: denotes that the all-reduce type is complex and cannot be represented merely by a simple sum, max, or min. The exact reduction computation may be derived from the semantics of the corresponding operation or from the reduction computation IR

Example:

mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])

// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>

// The tensor is sharded on the first dimension along axis 0 of @mesh0
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>

// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = sum[1]>>

// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>>

// Could be used in the attribute of mesh.shard op
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
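Assuming even division, the local shard shape implied by split_axes can be computed by dividing each tensor dimension by the sizes of the mesh axes it is split along. This is an illustrative sketch only; `shard_shape` is a hypothetical helper, not dialect API.

```python
def shard_shape(tensor_shape, mesh_shape, split_axes):
    """Local shard shape: tensor dim i is divided by the product of the
    mesh-axis sizes listed in split_axes[i] (assumes even division)."""
    shape = list(tensor_shape)
    for i, axes in enumerate(split_axes):
        for a in axes:
            assert shape[i] % mesh_shape[a] == 0, "uneven split"
            shape[i] //= mesh_shape[a]
    return shape

# @mesh0 above has dim_sizes = [2, 2, 4]; #mesh.shard<@mesh0, [[0]]> on a
# 4x8 tensor leaves each device with a 2x8 shard.
print(shard_shape([4, 8], [2, 2, 4], [[0]]))  # [2, 8]
```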

Parameters: 

Parameter | C++ type | Description
cluster | ::mlir::SymbolRefAttr | cluster placed
split_axes | ::llvm::ArrayRef<::mlir::DenseI32ArrayAttr> |
partial_axes | ::llvm::ArrayRef<int32_t> |
partial_type | ::mlir::mesh::Partial |

PartialAttr 

partial type of a distributed tensor

Syntax:

#mesh.partial<
  ::mlir::mesh::Partial   # value
>

Enum cases:

  • sum (Sum)
  • max (Max)
  • min (Min)
  • generic (Generic)

Parameters: 

Parameter | C++ type | Description
value | ::mlir::mesh::Partial | an enum of type Partial
