MLIR

Multi-Level IR Compiler Framework

'mesh' Dialect

The mesh dialect contains a set of attributes, operations and interfaces that are useful for representing sharding and communication on a device mesh cluster.

Collective Communication Operations 

There are a number of operations in the Mesh dialect to facilitate communication between devices in a mesh. It is assumed that the user is familiar with collective operations. Wikipedia has a good explanation. The main addition is that the collectives in this dialect have mesh semantics.

Device groups 

The operation attributes mesh and mesh_axes specify a list of device mesh axes that partition the devices into disjoint groups. The collective operation is performed between devices in the same group. Devices that have the same coordinates outside of the axes in mesh_axes are in the same group. A group is described by its multi-index along the axes outside of mesh_axes.

For example, if we have a device mesh of size 2x3x4x5 and the partition mesh axes list is [0, 1], then the devices are partitioned into the groups { { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }. The device groups are indexed by { (k, m) | 0<=k<4, 0<=m<5 }. Devices (1, 0, 2, 3) and (1, 1, 2, 3) are in the same group. Device (1, 0, 2, 4) is in another group.

Some collective operations, like all-to-all and all-gather, care about the order of devices. The order of devices in a device group is induced by the order of axes in mesh_axes. The axes are ordered from outer to inner. If we have an axis list [3, 1], then device (i, 1, k, 0) will precede both devices (i, 0, k, 1) and (i, 2, k, 0).
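The grouping and ordering rules above can be checked with a small simulation. This is only an illustrative Python sketch, not part of the dialect; the helper name device_groups is hypothetical.

```python
from itertools import product


def device_groups(mesh_shape, mesh_axes):
    """Partition all mesh devices into groups keyed by their coordinates
    outside mesh_axes. Members are ordered by their coordinates along
    mesh_axes, with the first listed axis as the outermost one."""
    groups = {}
    for device in product(*(range(n) for n in mesh_shape)):
        key = tuple(c for axis, c in enumerate(device) if axis not in mesh_axes)
        groups.setdefault(key, []).append(device)
    for members in groups.values():
        members.sort(key=lambda d: tuple(d[axis] for axis in mesh_axes))
    return groups


# The 2x3x4x5 mesh partitioned over mesh axes [0, 1].
groups = device_groups((2, 3, 4, 5), [0, 1])
print(len(groups))  # 20

# Device order inside one group for the axis list [3, 1].
ordered = device_groups((2, 3, 4, 5), [3, 1])[(0, 0)]
```

Here groups[(2, 3)] holds both (1, 0, 2, 3) and (1, 1, 2, 3), while (1, 0, 2, 4) lands in groups[(2, 4)].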

In-group Device 

Some operations, like broadcast, scatter and send, specify devices in each device group. These devices are represented by their multi-index over the mesh axes that are not constant within a device group. These are the axes specified by the mesh_axes attribute.

For example, on a 3D mesh an operation with mesh_axes = [0, 2] would specify an in-group device with (i, j). Then, for each group with index g on the second axis, the in-group device would be (i, g, j).
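As a rough illustration, the mapping from an in-group multi-index and a group index to full mesh coordinates might look as follows in Python. The function name full_device_index is hypothetical and exists only for this sketch.

```python
def full_device_index(mesh_rank, mesh_axes, in_group_index, group_index):
    """Combine an in-group multi-index (over mesh_axes) with a group
    multi-index (over the remaining axes) into full mesh coordinates."""
    coords = [None] * mesh_rank
    for axis, value in zip(mesh_axes, in_group_index):
        coords[axis] = value
    remaining = [axis for axis in range(mesh_rank) if axis not in mesh_axes]
    for axis, value in zip(remaining, group_index):
        coords[axis] = value
    return tuple(coords)


# mesh_axes = [0, 2], in-group device (i, j) = (1, 4), group index g = 7.
device = full_device_index(3, [0, 2], (1, 4), (7,))
print(device)  # (1, 7, 4)
```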

Purity 

Collectives that involve the whole device group to perform a single operation are pure. The exceptions are send and recv.

There is an assumption that the execution is SPMD: not only does each process run the same program, but at the point of execution of a collective operation all processes are in a coherent state. All compiler transformations must be consistent. Collective operations in the IR that may correspond to the same runtime collective operation must be transformed in a consistent manner. For example, if a collective operation is optimized out, then it must also not appear in any path of execution on any process.

Having the operations as Pure implies that if an interpreter is to execute the IR containing the mesh collectives, all processes would execute the same line when they reach a pure collective operation. This requirement stems from the need to be compatible with general optimization passes like dead code elimination and common subexpression elimination.

Operations 


mesh.all_gather (mesh::AllGatherOp) 

All-gather over a device mesh.

Syntax:

operation ::= `mesh.all_gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
              attr-dict `:` type($input) `->` type($result)

Gathers along the gather_axis tensor axis.

Example:

mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
  : tensor<2x2xi8> -> tensor<2x4xi8>

Input:

                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+

Result:

gather tensor
axis 1
------------>
+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
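The example above can be reproduced with a small Python simulation. It is a sketch only: tensors are modeled as nested lists, the function name all_gather is chosen for illustration, and only 2-D tensors are handled.

```python
def all_gather(shards, mesh_axes, gather_axis):
    """Simulate all-gather. `shards` maps device coordinates to a 2-D tensor
    stored as a list of rows; gather_axis is 0 or 1."""
    def group_key(device):
        # Coordinates outside mesh_axes identify the device group.
        return tuple(c for axis, c in enumerate(device) if axis not in mesh_axes)

    def in_group_index(device):
        # Coordinates along mesh_axes give the order inside the group.
        return tuple(device[axis] for axis in mesh_axes)

    result = {}
    for device in shards:
        members = sorted(
            (d for d in shards if group_key(d) == group_key(device)),
            key=in_group_index,
        )
        pieces = [shards[d] for d in members]
        if gather_axis == 0:
            result[device] = [list(row) for piece in pieces for row in piece]
        else:
            result[device] = [sum(rows, []) for rows in zip(*pieces)]
    return result


shards = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
gathered = all_gather(shards, mesh_axes=[1], gather_axis=1)
print(gathered[(0, 0)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
```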

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
gather_axis | ::mlir::IntegerAttr | index attribute

Operands:

Operand | Description
input | non-0-ranked tensor of any type values

Results:

Result | Description
result | non-0-ranked tensor of any type values

mesh.all_reduce (mesh::AllReduceOp) 

All-reduce over a device mesh.

Syntax:

operation ::= `mesh.all_reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
              attr-dict `:` type($input) `->` type($result)

The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.

Attributes: reduction: Indicates the reduction method.

Example:

%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
  : tensor<3x4xf32> -> tensor<3x4xf64>
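The "convert first, then reduce" rule can be sketched in Python as below. This is an illustrative model rather than the dialect's implementation: tensors are nested lists, the element-type conversion is modeled by a convert callable, and only three of the reduction kinds are included.

```python
from functools import reduce as fold

# A subset of the reduction kinds, as binary Python functions.
REDUCTIONS = {"sum": lambda a, b: a + b, "max": max, "min": min}


def all_reduce(shards, mesh_axes, reduction, convert=float):
    """Simulate all-reduce on 2-D tensors stored as lists of rows."""
    def group_key(device):
        return tuple(c for axis, c in enumerate(device) if axis not in mesh_axes)

    op = REDUCTIONS[reduction]
    result = {}
    for device in shards:
        members = [d for d in shards if group_key(d) == group_key(device)]
        # Convert every input element to the result element type first.
        converted = [[[convert(x) for x in row] for row in shards[d]] for d in members]
        # Then reduce element-wise across the devices of the group.
        result[device] = [
            [fold(op, elems) for elems in zip(*rows)] for rows in zip(*converted)
        ]
    return result


shards = {(0, 0): [[1, 8]], (0, 1): [[5, 2]], (1, 0): [[3, 3]], (1, 1): [[4, 0]]}
reduced_max = all_reduce(shards, mesh_axes=[1, 0], reduction="max")
reduced_sum = all_reduce(shards, mesh_axes=[1], reduction="sum")
print(reduced_max[(0, 0)])  # [[5.0, 8.0]]
```

With mesh_axes = [1, 0] all four devices form a single group, so every device ends up with the same tensor.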

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultShape

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
reduction | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. Enum cases: sum (Sum), max (Max), min (Min), product (Product), average (Average), bitwise_and (BitwiseAnd), bitwise_or (BitwiseOr), bitwise_xor (BitwiseXor), generic (Generic)

Operands:

Operand | Description
input | ranked tensor of any type values

Results:

Result | Description
result | ranked tensor of any type values

mesh.all_slice (mesh::AllSliceOp) 

All-slice over a device mesh. This is the inverse of all-gather.

Syntax:

operation ::= `mesh.all_slice` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
              attr-dict `:` type($input) `->` type($result)

Slice along the slice_axis tensor axis. This operation can be thought of as the inverse of all-gather. Technically, it is not required that all processes have the same input tensor. Each process will slice a piece of its local tensor based on its in-group device index. The operation does not communicate data between devices.

Example:

mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
  : tensor<2x4xi8> -> tensor<2x2xi8>

Input:

+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+

Result:

slice tensor
axis 1
------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
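A minimal Python sketch of the slicing rule follows. It assumes that the sliced dimension divides evenly by the group size and that the in-group index is linearized with the first axis in mesh_axes as the outermost one; the function name all_slice and the list-of-rows tensor model are illustrative only.

```python
def all_slice(tensors, mesh_shape, mesh_axes, slice_axis):
    """Each device keeps only the piece of its local 2-D tensor that is
    selected by its in-group device index. No data is exchanged."""
    result = {}
    for device, tensor in tensors.items():
        # Linearize the in-group index and compute the group size.
        index, group_size = 0, 1
        for axis in mesh_axes:
            index = index * mesh_shape[axis] + device[axis]
            group_size *= mesh_shape[axis]
        if slice_axis == 0:
            step = len(tensor) // group_size
            result[device] = [list(r) for r in tensor[index * step:(index + 1) * step]]
        else:
            step = len(tensor[0]) // group_size
            result[device] = [r[index * step:(index + 1) * step] for r in tensor]
    return result


top = [[1, 2, 5, 6], [3, 4, 7, 8]]
bottom = [[9, 10, 13, 14], [11, 12, 15, 16]]
tensors = {(0, 0): top, (0, 1): top, (1, 0): bottom, (1, 1): bottom}
sliced = all_slice(tensors, mesh_shape=(2, 2), mesh_axes=[1], slice_axis=1)
print(sliced[(0, 1)])  # [[5, 6], [7, 8]]
```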

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
slice_axis | ::mlir::IntegerAttr | index attribute

Operands:

Operand | Description
input | non-0-ranked tensor of any type values

Results:

Result | Description
result | non-0-ranked tensor of any type values

mesh.all_to_all (mesh::AllToAllOp) 

All-to-all over a device mesh.

Syntax:

operation ::= `mesh.all_to_all` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `split_axis` `=` $split_axis
              `concat_axis` `=` $concat_axis
              attr-dict `:` type($input) `->` type($result)

Performs an all-to-all on tensor pieces split along split_axis. The resulting pieces are concatenated along concat_axis on each device.

Example:

mesh.mesh @mesh0(shape = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
  split_axis = 0 concat_axis = 0
  : tensor<3x2xi8> -> tensor<3x2xi8>

Input:

 device  device  device
 (0)     (1)     (2)
+-------+-------+-------+  | split and concat along
| 11 12 | 21 22 | 31 32 |  | tensor axis 0
| 13 14 | 23 24 | 33 34 |  ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+

Result:

 device  device  device
 (0)     (1)     (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
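The data movement in this example can be sketched in Python for the special case of a 1-D mesh with split_axis = 0 and concat_axis = 0. The function name all_to_all_rows is hypothetical, and the sketch assumes the row count divides evenly by the number of devices.

```python
def all_to_all_rows(per_device):
    """All-to-all on a 1-D mesh with split_axis = 0 and concat_axis = 0.
    per_device[d] is the 2-D tensor (list of rows) held by device d."""
    n = len(per_device)
    step = len(per_device[0]) // n  # rows per piece
    result = []
    for receiver in range(n):
        rows = []
        for sender in range(n):
            # Piece number `receiver` of each sender's tensor goes to `receiver`;
            # pieces are concatenated in sender order.
            rows.extend(per_device[sender][receiver * step:(receiver + 1) * step])
        result.append(rows)
    return result


inputs = [
    [[11, 12], [13, 14], [15, 16]],  # device (0)
    [[21, 22], [23, 24], [25, 26]],  # device (1)
    [[31, 32], [33, 34], [35, 36]],  # device (2)
]
outputs = all_to_all_rows(inputs)
print(outputs[0])  # [[11, 12], [21, 22], [31, 32]]
```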

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
split_axis | ::mlir::IntegerAttr | index attribute
concat_axis | ::mlir::IntegerAttr | index attribute

Operands:

Operand | Description
input | non-0-ranked tensor of any type values

Results:

Result | Description
result | non-0-ranked tensor of any type values

mesh.broadcast (mesh::BroadcastOp) 

Broadcast over a device mesh.

Syntax:

operation ::= `mesh.broadcast` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

Broadcast the tensor on root to all devices in each respective group. The operation broadcasts along mesh axes mesh_axes. The root device specifies the in-group multi-index that is broadcast to all other devices in the group.

Example:

mesh.mesh @mesh0(shape = 2x2)

%1 = mesh.broadcast %0 on @mesh0
  mesh_axes = [0]
  root = [0]
  : (tensor<2xi8>) -> tensor<2xi8>

Input:

                 +-------+-------+                   | broadcast
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                 +-------+-------+                   ↓
device (1, 0) -> |       |       | <- device (1, 1) 
                 +-------+-------+

Output:

                 +-------+-------+
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
                 +-------+-------+
device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
                 +-------+-------+
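As an illustration, the broadcast in this example could be modeled in Python as below. The function name broadcast and the dictionary-of-lists tensor model are assumptions made for the sketch; None stands in for the unspecified input on non-root devices.

```python
def broadcast(tensors, mesh_axes, root):
    """Copy the tensor held by each group's root device to every device of
    that group. `root` is the in-group multi-index along mesh_axes."""
    result = {}
    for device in tensors:
        source = list(device)
        for axis, coord in zip(mesh_axes, root):
            source[axis] = coord  # same group, root position
        result[device] = list(tensors[tuple(source)])
    return result


tensors = {(0, 0): [1, 2], (0, 1): [3, 4], (1, 0): None, (1, 1): None}
output = broadcast(tensors, mesh_axes=[0], root=[0])
print(output[(1, 0)])  # [1, 2]
```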

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | ranked tensor of any type values
root_dynamic | variadic of index

Results:

Result | Description
result | ranked tensor of any type values

mesh.gather (mesh::GatherOp) 

Gather over a device mesh.

Syntax:

operation ::= `mesh.gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `gather_axis` `=` $gather_axis
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

Gathers on device root along the gather_axis tensor axis. root specifies the coordinates of a device along mesh_axes. It uniquely identifies the root device for each device group. The result tensor on non-root devices is undefined. Using it will result in undefined behavior.

Example:

mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
  gather_axis = 1 root = [1]
  : (tensor<2x2xi8>) -> tensor<2x4xi8>

Input:

                  gather tensor
                  axis 1
                  ------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+

Result:

+-------------+
|  1  2  5  6 | <- device (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- device (1, 1)
| 11 12 15 16 |
+-------------+

Devices (0, 0) and (1, 0) have undefined result.
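The same example can be modeled with a short Python sketch. It only handles gathering along tensor axis 1, and it uses None as a stand-in for the undefined result on non-root devices; both choices, and the name gather_axis1, are assumptions of this illustration.

```python
def gather_axis1(shards, mesh_axes, root):
    """Gather 2-D tensors along tensor axis 1 onto each group's root device."""
    def group_key(device):
        return tuple(c for axis, c in enumerate(device) if axis not in mesh_axes)

    def in_group_index(device):
        return tuple(device[axis] for axis in mesh_axes)

    result = {}
    for device in shards:
        if in_group_index(device) != tuple(root):
            result[device] = None  # undefined on non-root devices
            continue
        members = sorted(
            (d for d in shards if group_key(d) == group_key(device)),
            key=in_group_index,
        )
        result[device] = [sum(rows, []) for rows in zip(*(shards[d] for d in members))]
    return result


shards = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
gathered = gather_axis1(shards, mesh_axes=[1], root=[1])
print(gathered[(0, 1)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
```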

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
gather_axis | ::mlir::IntegerAttr | index attribute
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | non-0-ranked tensor of any type values
root_dynamic | variadic of index

Results:

Result | Description
result | non-0-ranked tensor of any type values

mesh.mesh (mesh::MeshOp) 

Description of a device/process mesh.

Syntax:

operation ::= `mesh.mesh` $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
              attr-dict

The mesh.mesh operation is a symbol operation that identifies a specific mesh. The operation has two attributes:

  1. sym_name: This attribute uniquely identifies the name of the mesh. This name serves as a symbolic reference to the mesh throughout the MLIR module, allowing for consistent referencing and easier debugging.

  2. shape: This attribute represents the shape of the device mesh. It uses the same notation as a tensor shape and also allows dynamic dimensions, for example 2x?x4. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be determined at compile time.

Example:

// A device mesh with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12 
mesh.mesh @mesh0(shape = 4x8x12)

// A device mesh with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.mesh @mesh1(shape = 4x?)

// A device mesh with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.mesh @mesh2(shape = ?x4)

// A device mesh with 2 axes, the number of devices along both axes
// is unknown
mesh.mesh @mesh3(shape = ?x?)
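The device counts mentioned in the comments above follow directly from the shape string. A small Python sketch, with a hypothetical helper total_devices that returns None when any dimension is dynamic:

```python
from math import prod


def total_devices(shape):
    """Return the total device count for a shape such as '4x8x12', or None
    if any dimension is dynamic ('?')."""
    dims = shape.split("x")
    if "?" in dims:
        return None
    return prod(int(d) for d in dims)


print(total_devices("4x8x12"))  # 384
print(total_devices("4x?"))     # None
```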

Interfaces: Symbol

Attributes: 

Attribute | MLIR Type | Description
sym_name | ::mlir::StringAttr | string attribute
shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

mesh.mesh_shape (mesh::MeshShapeOp) 

Get the shape of the mesh.

Syntax:

operation ::= `mesh.mesh_shape` $mesh (`axes` `=` $axes^)?
              attr-dict `:` type($result)

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute

Results:

Result | Description
result | variadic of index

mesh.process_linear_index (mesh::ProcessLinearIndexOp) 

Get the linear index of the current device.

Syntax:

operation ::= `mesh.process_linear_index` `on` $mesh attr-dict `:` type($result)

Example:

%idx = mesh.process_linear_index on @mesh : index

If @mesh has shape (10, 20, 30), a device with multi-index (1, 2, 3) will have linear index 3 + 30*2 + 20*30*1 = 663.
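The linearization used in this example is the usual row-major one, where the last mesh axis varies fastest. A minimal Python sketch (the function name linear_index is illustrative):

```python
def linear_index(multi_index, mesh_shape):
    """Row-major linearization: the last mesh axis varies fastest."""
    index = 0
    for coord, size in zip(multi_index, mesh_shape):
        index = index * size + coord
    return index


print(linear_index((1, 2, 3), (10, 20, 30)))  # 663
```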

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute

Results:

Result | Description
result | index

mesh.process_multi_index (mesh::ProcessMultiIndexOp) 

Get the multi index of current device along specified mesh axes.

Syntax:

operation ::= `mesh.process_multi_index` `on` $mesh (`axes` `=` $axes^)?
              attr-dict `:` type($result)

It is used in the SPMD format of IR. The axes must be non-negative and less than the total number of mesh axes. If the axes are empty, the index along all axes is returned.
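To illustrate what the result looks like, the following Python sketch derives a multi-index from a row-major linear index and then selects the requested axes. It assumes the results come back in the order the axes are listed, which is not stated here; the helper name multi_index is hypothetical.

```python
def multi_index(linear, mesh_shape, axes=()):
    """Delinearize a row-major linear index and return the coordinates along
    `axes`, or along all mesh axes when `axes` is empty."""
    coords = []
    for size in reversed(mesh_shape):
        coords.append(linear % size)
        linear //= size
    coords.reverse()
    if not axes:
        return tuple(coords)
    return tuple(coords[axis] for axis in axes)


print(multi_index(663, (10, 20, 30)))          # (1, 2, 3)
print(multi_index(663, (10, 20, 30), [2, 0]))  # (3, 1)
```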

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute

Results:

Result | Description
result | variadic of index

mesh.recv (mesh::RecvOp) 

Receive over a device mesh.

Syntax:

operation ::= `mesh.recv` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
              attr-dict `:` functional-type(operands, results)

Receive from a device within a device group.

Interfaces: OpAsmOpInterface, SymbolUserOpInterface

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
source | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | non-0-ranked tensor of any type values
source_dynamic | variadic of index

Results:

Result | Description
result | ranked tensor of any type values

mesh.reduce (mesh::ReduceOp) 

Reduce over a device mesh.

Syntax:

operation ::= `mesh.reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              (`reduction` `=` $reduction^)?
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

Reduces on device root within each device group. root specifies the coordinates of a device along mesh_axes. It uniquely identifies the root device within its device group. The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.

Attributes: reduction: Indicates the reduction method.

Example:

%1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
  reduction = <max> root = [2, 3]
  : (tensor<3x4xf32>) -> tensor<3x4xf64>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
reduction | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. Enum cases: sum (Sum), max (Max), min (Min), product (Product), average (Average), bitwise_and (BitwiseAnd), bitwise_or (BitwiseOr), bitwise_xor (BitwiseXor), generic (Generic)
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | ranked tensor of any type values
root_dynamic | variadic of index

Results:

Result | Description
result | ranked tensor of any type values

mesh.reduce_scatter (mesh::ReduceScatterOp) 

Reduce-scatter over a device mesh.

Syntax:

operation ::= `mesh.reduce_scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              (`reduction` `=` $reduction^)?
              `scatter_axis` `=` $scatter_axis
              attr-dict `:` type($input) `->` type($result)

After the reduction, the result is scattered within each device group. The tensor is split along scatter_axis and the pieces are distributed across the device group.

Example:

mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
  reduction = <sum> scatter_axis = 0
  : tensor<2x2xf32> -> tensor<1x2xf64>

Input:

                          device
                          (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                 |  3  4 |  7  8 |  ↓
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 |
                 | 11 12 | 15 16 |
                 +-------+-------+
                            ↑
                          device
                          (1, 1)

Result:

+-------+
|  6  8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
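The values in the Input and Result diagrams correspond to an element-wise sum over each group of 2x2 tensors, followed by a split along tensor axis 0. The Python sketch below reproduces them under those assumptions; the function name reduce_scatter_sum_axis0 is hypothetical and the split is assumed to be even.

```python
def reduce_scatter_sum_axis0(shards, mesh_axes):
    """Sum 2-D tensors element-wise within each device group, then hand each
    device the piece of the sum (split along axis 0) at its group position."""
    def group_key(device):
        return tuple(c for axis, c in enumerate(device) if axis not in mesh_axes)

    def in_group_index(device):
        return tuple(device[axis] for axis in mesh_axes)

    result = {}
    for device in shards:
        members = sorted(
            (d for d in shards if group_key(d) == group_key(device)),
            key=in_group_index,
        )
        total = [
            [sum(elems) for elems in zip(*rows)]
            for rows in zip(*(shards[d] for d in members))
        ]
        step = len(total) // len(members)
        position = members.index(device)
        result[device] = total[position * step:(position + 1) * step]
    return result


shards = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
scattered = reduce_scatter_sum_axis0(shards, mesh_axes=[1])
print(scattered[(0, 1)])  # [[10, 12]]
```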

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
reduction | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. Enum cases: sum (Sum), max (Max), min (Min), product (Product), average (Average), bitwise_and (BitwiseAnd), bitwise_or (BitwiseOr), bitwise_xor (BitwiseXor), generic (Generic)
scatter_axis | ::mlir::IntegerAttr | index attribute

Operands:

Operand | Description
input | non-0-ranked tensor of any type values

Results:

Result | Description
result | ranked tensor of any type values

mesh.scatter (mesh::ScatterOp) 

Scatter over a device mesh.

Syntax:

operation ::= `mesh.scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `scatter_axis` `=` $scatter_axis
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

For each device group, split the input tensor on the root device along axis scatter_axis and scatter the parts across the group devices.

Example:

mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
  scatter_axis = 0
  root = [1]
  : (tensor<2x2xi8>) -> tensor<1x2xi8>

Input:

                          device
                          (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |       |       |  | axis 0
                 |       |       |  ↓
                 +-------+-------+
device (1, 0) -> |  1  2 |  5  6 |
                 |  3  4 |  7  8 |
                 +-------+-------+
                            ↑
                          device
                          (1, 1)

Result:

                          device
                          (0, 1)
                             ↓
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 |
                 +-------+-------+ 
device (1, 0) -> |  3  4 |  7  8 |
                 +-------+-------+
                            ↑
                          device
                          (1, 1)
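This example can be reproduced with a short Python sketch that handles scattering along tensor axis 0 only. The function name scatter_axis0 is hypothetical, None models the unused input on non-root devices, and the split is assumed to be even.

```python
def scatter_axis0(tensors, mesh_axes, root):
    """Split the root device's 2-D tensor along axis 0 and give each device
    of the group the piece at its in-group position."""
    def group_key(device):
        return tuple(c for axis, c in enumerate(device) if axis not in mesh_axes)

    def in_group_index(device):
        return tuple(device[axis] for axis in mesh_axes)

    result = {}
    for device in tensors:
        members = sorted(
            (d for d in tensors if group_key(d) == group_key(device)),
            key=in_group_index,
        )
        root_device = next(d for d in members if in_group_index(d) == tuple(root))
        source = tensors[root_device]
        step = len(source) // len(members)
        position = members.index(device)
        result[device] = source[position * step:(position + 1) * step]
    return result


tensors = {
    (0, 0): None,
    (0, 1): None,
    (1, 0): [[1, 2], [3, 4]],
    (1, 1): [[5, 6], [7, 8]],
}
scattered = scatter_axis0(tensors, mesh_axes=[0], root=[1])
print(scattered[(0, 0)])  # [[1, 2]]
```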

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
scatter_axis | ::mlir::IntegerAttr | index attribute
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | non-0-ranked tensor of any type values
root_dynamic | variadic of index

Results:

Result | Description
result | ranked tensor of any type values

mesh.send (mesh::SendOp) 

Send over a device mesh.

Syntax:

operation ::= `mesh.send` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
              attr-dict `:` functional-type(operands, results)

Send from one device to another within a device group.

Interfaces: OpAsmOpInterface, SymbolUserOpInterface

Attributes: 

Attribute | MLIR Type | Description
mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
destination | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | non-0-ranked tensor of any type values
destination_dynamic | variadic of index

Results:

Result | Description
result | ranked tensor of any type values

mesh.shard (mesh::ShardOp) 

Annotate how a tensor is sharded across a mesh.

Syntax:

operation ::= `mesh.shard` $src `to` $sharding
              (`annotate_for_users` $annotate_for_users^)?
              attr-dict `:` type($result)

The mesh.shard operation is designed to specify and guide the sharding behavior of a tensor value across a mesh topology. This operation has two operands and one optional attribute:

  1. src: This operand represents the tensor value that needs to be annotated for sharding.

  2. sharding: This operand is of type MeshShardingType, which is the core data structure to represent the distribution of a tensor on a mesh. It is typically defined by a mesh.sharding operation.

  3. annotate_for_users: A unit attribute addressing the scenario when a tensor’s sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.

Example:

func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
  ...
}

func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
  ...
}

func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
  %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
  %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
  ...
}

// The first mesh.shard op applies to %arg0, the second mesh.shard op
// applies to the operand of op0, and the third mesh.shard op applies to the
// operand of op1
func.func @both_result_and_multi_operands_annotated(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
  %sharding1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %1 = mesh.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %sharding2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
  %2 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
  "op0"(%1) : ...
  "op1"(%2) : ...
  ...
}

The following usages are undefined:

func.func @annotate_on_same_result_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
  %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_result_same_value_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
  %1 = mesh.shard %arg0 to %sharding2 : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_operand_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %1 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
  ...
}

func.func @result_annotated_after_operand(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
  %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
  ...
}

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
annotate_for_users | ::mlir::UnitAttr | unit attribute

Operands:

Operand | Description
src | ranked tensor of any type values
sharding | sharding definition

Results:

Result | Description
result | ranked tensor of any type values

mesh.shard_shape (mesh::ShardShapeOp) 

Get the shard shape of a given process/device.

Syntax:

operation ::= `mesh.shard_shape` custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)

The device/process id is a linearized id of the device/process in the mesh. This operation might be used during spmdization when the shard shape depends on (non-constant) values used in mesh.sharding.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
sharding | sharding definition
device | index

Results:

Result | Description
result | variadic of index

mesh.sharding (mesh::ShardingOp) 

Define a sharding of a tensor.

Syntax:

operation ::= `mesh.sharding` $mesh
              `split_axes` `=` $split_axes
              (`partial` `=` $partial_type $partial_axes^)?
              (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
              (`sharded_dims_sizes` `=` custom<DynamicIndexList>($dynamic_sharded_dims_sizes, $static_sharded_dims_sizes)^)?
              attr-dict `:` type($result)

The MeshSharding specifies how a tensor is sharded and distributed across the process mesh. It is typically used in a mesh.shard operation. The operation has the follwing attributes and operands:

  1. mesh: this attribute is a FlatSymbolRefAttr that refers to the device mesh where the distributed tensor is placed. The symbol must resolve to a mesh.mesh operation.

  2. split_axes: is an array composed of int64_t sub-arrays. The outer array’s maximum size is the rank of the related tensor. For the i-th sub-array, if its value is [x, y], it indicates that the tensor’s i-th dimension is splitted along the x and y axes of the device mesh.

  3. [Optional] partial_axes: if not empty, this signifies that the tensor is partial one along the specified mesh axes. An all-reduce should be applied to obtain the complete tensor, with reduction type being specified by partial_type.

  4. [Optional] partial_type: indicates the reduction type of the possible all-reduce op. It has 4 possible values: generic: is not an allowed value inside a shard attribute.

  5. [Optional] Sizes of halos to be added for each sharded tensor dimension. halo_sizesis provided as a flattened 1d array of i64s, 2 values for each sharded dimension. halo_sizes = [1, 2] means that the first sharded dimension gets an additional halo of size 1 at the start of the first dimension and a halo size is 2 at its end. halo_sizes = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions e.g. the first sharded dimension gets [1,2] halos and the seconds gets [2,3] halos. ? indicates dynamic halo sizes.

  6. [Optional] Sizes of sharded dimensions of each shard. sharded_dims_sizes is provided as a flattened 1d array of i64s: for each device of the device-mesh, one value for each sharded tensor dimension. Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded, sharded_dims_sizes = [16, 8, 16, 24] means that the first device of the device-mesh will get a shard of shape 16x8x32 and the second device will get a shard of shape 16x24x32. ? indicates dynamic shard dimensions.

halo_sizes and sharded_dims_sizes are mutually exclusive.
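
The effect of split_axes on the local shard shape can be illustrated with a small Python sketch. This is not part of the dialect: the helper name shard_shape is hypothetical, and the sketch assumes that every sharded dimension divides evenly by the product of the sizes of the mesh axes it is split along, and that tensor dimensions without a sub-array are not split.

```python
from math import prod


def shard_shape(tensor_shape, mesh_shape, split_axes):
    """Per-device shard shape for an evenly divisible sharding.

    tensor_shape: global tensor shape, e.g. (4, 8)
    mesh_shape:   device mesh shape, e.g. (2, 2, 4)
    split_axes:   one list of mesh axes per tensor dimension, e.g. [[0], []]
    """
    shape = []
    for dim, size in enumerate(tensor_shape):
        # Assumption: dimensions without a sub-array are not split.
        axes = split_axes[dim] if dim < len(split_axes) else []
        # Number of pieces along this dimension (prod of nothing is 1).
        parts = prod(mesh_shape[a] for a in axes)
        if size % parts != 0:
            raise ValueError(f"dimension {dim} is not evenly divisible")
        shape.append(size // parts)
    return tuple(shape)


mesh0 = (2, 2, 4)
print(shard_shape((4, 8), mesh0, [[]]))           # fully replicated
print(shard_shape((4, 8), mesh0, [[0]]))          # dim 0 split along mesh axis 0
print(shard_shape((4, 8), mesh0, [[0], [1, 2]]))  # dim 1 split along axes 1 and 2
```

With a sub-array [x, y], the dimension is divided into as many pieces as the product of the sizes of mesh axes x and y, which is why the last call yields a shard of shape 2x1.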

Examples:

mesh.mesh @mesh0(shape = 2x2x4)
mesh.mesh @mesh1d_4(shape = 4)

// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
%sharding0 = mesh.sharding @mesh0 split_axes = [[]]

// The tensor is sharded on the first dimension along axis 0 of @mesh0
%sharding1 = mesh.sharding @mesh0 split_axes = [[0]]

// The tensor is sharded on its first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
%sharding2 = mesh.sharding @mesh0 split_axes = [[0], []] partial = sum[1]

// The tensor is sharded on its first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
%sharding3 = mesh.sharding @mesh0 split_axes = [[0]] partial = max[1]

// Could be used for a mesh.shard op
%sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>

// The tensor is sharded on its first dimension along axis 0 of @mesh0
// and it has halo-sizes of 1 and 2 on the sharded dim.
%halo_sharding = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2]
%sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>

// The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
%sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_sizes = [2, 3, 4, 5]
%sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
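
The flattened layouts of sharded_dims_sizes and halo_sizes described above can be decoded as in the following Python sketch. The helper names and the sharded_dims parameter are hypothetical and only serve the illustration. The second helper assumes that a halo enlarges the local extent of a sharded dimension by its start and end sizes.

```python
def shard_shapes_from_sizes(tensor_shape, sharded_dims, sharded_dims_sizes):
    """Decode a flattened sharded_dims_sizes list into one shape per device.

    sharded_dims lists the tensor dimensions that are sharded (hypothetical
    parameter); unsharded dimensions keep their full size.
    """
    k = len(sharded_dims)
    if k == 0 or len(sharded_dims_sizes) % k != 0:
        raise ValueError("expected one value per sharded dimension per device")
    shapes = []
    for start in range(0, len(sharded_dims_sizes), k):
        shape = list(tensor_shape)
        for dim, size in zip(sharded_dims, sharded_dims_sizes[start:start + k]):
            shape[dim] = size
        shapes.append(tuple(shape))
    return shapes


def extents_with_halos(core_sizes, halo_sizes):
    """Add (start, end) halo pairs to the local extents of the sharded dims."""
    if len(halo_sizes) != 2 * len(core_sizes):
        raise ValueError("expected two halo values per sharded dimension")
    return [
        core + halo_sizes[2 * i] + halo_sizes[2 * i + 1]
        for i, core in enumerate(core_sizes)
    ]


# The 32x32x32 example: the first two dimensions are sharded.
print(shard_shapes_from_sizes((32, 32, 32), [0, 1], [16, 8, 16, 24]))
# The tensor<4x14xf32> example: only the second dimension is sharded.
print(shard_shapes_from_sizes((4, 14), [1], [2, 3, 4, 5]))
# A local extent of 2 with halo_sizes = [1, 2].
print(extents_with_halos([2], [1, 2]))
```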

Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
| split_axes | ::mlir::mesh::MeshAxesArrayAttr | |
| partial_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
| partial_type | ::mlir::mesh::ReductionKindAttr | Reduction of an iterator/mesh dimension. Enum cases: sum (Sum), max (Max), min (Min), product (Product), average (Average), bitwise_and (BitwiseAnd), bitwise_or (BitwiseOr), bitwise_xor (BitwiseXor), generic (Generic) |
| static_sharded_dims_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| static_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands: 

| Operand | Description |
| --- | --- |
| dynamic_sharded_dims_sizes | variadic of 64-bit signless integer |
| dynamic_halo_sizes | variadic of 64-bit signless integer |

Results: 

| Result | Description |
| --- | --- |
| result | sharding definition |

mesh.shift (mesh::ShiftOp) 

Shift over a device mesh.

Syntax:

operation ::= `mesh.shift` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `shift_axis` `=` $shift_axis
              `offset` `=` $offset
              (`rotate` $rotate^)?
              attr-dict `:` type($input) `->` type($result)

Within each device group, shifts the input along mesh axis shift_axis by offset. The result on devices that do not have a corresponding source is undefined. shift_axis must be one of mesh_axes. If the rotate attribute is present, a rotation is performed instead of a shift.

Example:

mesh.mesh @mesh0(shape = 2x4)
%1 = mesh.shift %0 on @mesh0 mesh_axes = [1]
  shift_axis = 1 offset = 2 rotate
  : tensor<2xi8> -> tensor<2xi8>

Input:

mesh axis 1
----------->

+----+----+----+----+
|  1 |  2 |  3 |  4 |
+----+----+----+----+
|  5 |  6 |  7 |  8 |
+----+----+----+----+

Result:

+----+----+----+----+
|  3 |  4 |  1 |  2 |
+----+----+----+----+
|  7 |  8 |  5 |  6 |
+----+----+----+----+
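
The data movement within one device group can be modelled with a short Python sketch. It is an illustration only: the function name is hypothetical, undefined results are represented by None, and the shift direction (a positive offset moves data towards higher indices along shift_axis) is an assumption. The rotation by 2 on an axis of size 4 shown above gives the same result in either direction.

```python
def shift_along_axis(values, offset, rotate=False):
    """Model mesh.shift for the devices of one group along shift_axis.

    values[i] is the data held by the device at index i on shift_axis.
    Assumption: the device at index i receives from index i - offset.
    """
    n = len(values)
    result = []
    for i in range(n):
        src = i - offset
        if rotate:
            # Rotation wraps around the axis, so every device has a source.
            result.append(values[src % n])
        elif 0 <= src < n:
            result.append(values[src])
        else:
            # No corresponding source: the result is undefined.
            result.append(None)
    return result


# The two device groups (rows) of the 2x4 mesh in the example above.
print(shift_along_axis([1, 2, 3, 4], offset=2, rotate=True))
print(shift_along_axis([5, 6, 7, 8], offset=2, rotate=True))
# Without rotate, devices lacking a source get an undefined result.
print(shift_along_axis([1, 2, 3, 4], offset=1))
```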

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultShape

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
| mesh_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute |
| shift_axis | ::mlir::IntegerAttr | index attribute |
| offset | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| rotate | ::mlir::UnitAttr | unit attribute |

Operands: 

| Operand | Description |
| --- | --- |
| input | non-0-ranked tensor of any type values |

Results: 

| Result | Description |
| --- | --- |
| result | ranked tensor of any type values |

mesh.update_halo (mesh::UpdateHaloOp) 

Update halo data.

Syntax:

operation ::= `mesh.update_halo` $input `on` $mesh
              `split_axes` `=` $split_axes
              (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
              (`target_halo_sizes` `=` $target_halo_sizes^)?
              attr-dict `:` type($input)

This operation updates the halo regions of shards, e.g. when their sharding specifies halos and the actual tensor data might have changed on remote devices. Changes might be caused by mutating operations and/or by new halo regions that are larger than the existing ones.

Assumes all devices hold tensors with same-sized halo data as specified by dynamic/static_halo_sizes.

split_axes specifies for each tensor axis along which mesh axes its halo data is updated.

Optionally resizes to new halo sizes target_halo_sizes.
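
As a rough illustration of what refreshing halos means, the following Python sketch performs a one-dimensional halo exchange between neighbouring shards. It is not the implementation of this operation: the function name and the [start halo | core | end halo] layout are assumptions, shards at the ends of the mesh axis keep their outer halos unchanged, and each core is assumed to be at least as long as the halos it supplies.

```python
def update_halos_1d(shards, lo, hi):
    """Refresh halos of 1-d shards laid out as [lo halo | core | hi halo].

    shards: one list per device, ordered along the mesh axis
    lo, hi: halo sizes at the start and end of each shard
    """
    # Core (owned) data of every shard, without its halos.
    cores = [s[lo:len(s) - hi] for s in shards]
    updated = []
    for i, shard in enumerate(shards):
        new = list(shard)
        if i > 0 and lo > 0:
            # Start halo mirrors the last elements of the previous core.
            new[:lo] = cores[i - 1][-lo:]
        if i + 1 < len(shards) and hi > 0:
            # End halo mirrors the first elements of the next core.
            new[len(new) - hi:] = cores[i + 1][:hi]
        updated.append(new)
    return updated


# Two devices, halo size 1 on both sides; halos start out stale (0).
stale = [[0, 1, 2, 0], [0, 3, 4, 0]]
print(update_halos_1d(stale, lo=1, hi=1))
```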

Interfaces: SymbolUserOpInterface

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| mesh | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute |
| split_axes | ::mlir::mesh::MeshAxesArrayAttr | |
| static_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| target_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands: 

| Operand | Description |
| --- | --- |
| input | non-0-ranked memref of any type values |
| dynamic_halo_sizes | variadic of 64-bit signless integer |

Attributes 

MeshAxesArrayAttr 

Syntax:

#mesh.axisarray<
  ::llvm::ArrayRef<MeshAxesAttr>   # axes
>

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| axes | ::llvm::ArrayRef<MeshAxesAttr> | |

ReductionKindAttr 

Reduction of an iterator/mesh dimension.

Syntax:

#mesh.partial<
  ::mlir::mesh::ReductionKind   # value
>

Enum cases:

  • sum (Sum)
  • max (Max)
  • min (Min)
  • product (Product)
  • average (Average)
  • bitwise_and (BitwiseAnd)
  • bitwise_or (BitwiseOr)
  • bitwise_xor (BitwiseXor)
  • generic (Generic)

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| value | ::mlir::mesh::ReductionKind | an enum of type ReductionKind |