'mesh' Dialect

The mesh dialect contains a set of attributes, operations and interfaces that are useful for representing sharding and communication on a device mesh cluster.

Collective Communication Operations 

There are a number of operations in the Mesh dialect to facilitate communication between devices in a mesh. It is assumed that the user is familiar with collective operations. Wikipedia has a good explanation. The main addition is that the collectives in this dialect have mesh semantics.

Device groups 

The operation attributes mesh and mesh_axes specify a list of device mesh axes that partition the devices into disjoint groups. The collective operation is performed between devices in the same group. Devices that have the same coordinates outside of the axes in mesh_axes are in the same group. A group is described by its multi-index along the axes outside of mesh_axes. For example, if we have a device mesh of size 2x3x4x5 and the partition mesh axes list is [0, 1], then the devices are partitioned into the groups { { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }, indexed by { (k, m) | 0<=k<4, 0<=m<5 }. Devices (1, 0, 2, 3) and (1, 1, 2, 3) are in the same group, while device (1, 0, 2, 4) is in another group.

Some collective operations, like all-to-all and all-gather, care about the order of devices. The order of devices in a device group is induced by the order of axes in mesh_axes, with the axes ordered from outer to inner. If we have the axis list [3, 1], then device (i, 1, k, 0) will precede both devices (i, 0, k, 1) and (i, 2, k, 0).
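
The grouping and ordering rules can be made concrete with a small sketch (plain Python, illustrative only; the dialect itself defines no such code):

import itertools

mesh_shape = (2, 3, 4, 5)
mesh_axes = [0, 1]
other_axes = [a for a in range(len(mesh_shape)) if a not in mesh_axes]

groups = {}
for device in itertools.product(*(range(s) for s in mesh_shape)):
    group_index = tuple(device[a] for a in other_axes)    # (k, m)
    in_group_index = tuple(device[a] for a in mesh_axes)  # (i, j)
    groups.setdefault(group_index, []).append(in_group_index)

# Devices (1, 0, 2, 3) and (1, 1, 2, 3) share group (2, 3);
# device (1, 0, 2, 4) is in group (2, 4).
assert (1, 0) in groups[(2, 3)] and (1, 1) in groups[(2, 3)]

# In-group order is lexicographic in the in-group multi-index, with the
# axes taken in the order given by mesh_axes.
assert groups[(2, 3)] == sorted(groups[(2, 3)])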

In-group Device 

Some operations, like broadcast, scatter and send, specify devices in each device group. These devices are represented by their multi-index over the mesh axes that are not constant within a device group, i.e. the axes specified by the mesh_axes attribute.

For example, on a 3D mesh an operation with mesh_axes = [0, 2] specifies an in-group device with a multi-index (i, j). For the group with index g along the second mesh axis, that in-group device has the full mesh coordinates (i, g, j).
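
A minimal sketch of this mapping, using a hypothetical helper full_coords (plain Python, not part of the dialect):

def full_coords(mesh_rank, mesh_axes, in_group_index, group_index):
    # Place the in-group coordinates on mesh_axes and the group
    # coordinates on the remaining axes.
    coords = [None] * mesh_rank
    for axis, c in zip(mesh_axes, in_group_index):
        coords[axis] = c
    other_axes = [a for a in range(mesh_rank) if a not in mesh_axes]
    for axis, c in zip(other_axes, group_index):
        coords[axis] = c
    return tuple(coords)

# mesh_axes = [0, 2] on a 3D mesh: the in-group device (i, j) = (1, 0) in
# the group with index g = 2 along axis 1 is device (1, 2, 0), i.e. (i, g, j).
assert full_coords(3, [0, 2], (1, 0), (2,)) == (1, 2, 0)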

Purity 

Collectives in which the whole device group participates in a single operation are pure. The exceptions are send and recv.

There is an assumption that the execution is SPMD: not only does each process run the same program, but at the point of execution of a collective operation, all processes are in a coherent state. All compiler transformations must preserve this consistency. Collective operations in the IR that may correspond to the same runtime collective operation must be transformed in the same manner. For example, if a collective operation is optimized out, it must not appear on any execution path of any process.

Marking the operations Pure implies that if an interpreter were to execute IR containing the mesh collectives, all processes would execute the same line when they reach a pure collective operation. This requirement stems from the need to be compatible with general optimization passes like dead code elimination and common sub-expression elimination.

Operations 

mesh.all_gather (mesh::AllGatherOp) 

All-gather over a device mesh.

Syntax:

operation ::= `mesh.all_gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
              attr-dict `:` type($input) `->` type($result)

Gathers along the gather_axis tensor axis.

Example:

mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
  : tensor<2x2xi8> -> tensor<2x4xi8>

Input:

                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+

Result:

gather tensor
axis 1
------------>
+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
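
The example's semantics for the group formed by devices (0, 0) and (0, 1) can be sketched with numpy (illustrative only; not how the op is implemented):

import numpy as np

# Local shards of devices (0, 0) and (0, 1).
shard_00 = np.array([[1, 2], [3, 4]], dtype=np.int8)
shard_01 = np.array([[5, 6], [7, 8]], dtype=np.int8)

# All-gather concatenates the group's shards along gather_axis, in
# in-group device order; every device receives the same result.
gathered = np.concatenate([shard_00, shard_01], axis=1)
assert gathered.tolist() == [[1, 2, 5, 6], [3, 4, 7, 8]]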

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute    MLIR Type                  Description
mesh         ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
mesh_axes    ::mlir::DenseI16ArrayAttr  i16 dense array attribute
gather_axis  ::mlir::IntegerAttr        index attribute

Operands:

Operand  Description
input    non-0-ranked tensor of any type values

Results:

Result  Description
result  non-0-ranked tensor of any type values

mesh.all_reduce (mesh::AllReduceOp) 

All-reduce over a device mesh.

Syntax:

operation ::= `mesh.all_reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
              attr-dict `:` type($input) `->` type($result)

The accumulation element type is specified by the result type, and it does not need to match the input element type. Input elements are converted to the result element type before the reduction is performed.

Attributes: reduction: Indicates the reduction method.

Example:

%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
  : tensor<3x4xf32> -> tensor<3x4xf64>
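
A numpy sketch of this conversion rule, assuming a group of two devices and max reduction (illustrative only):

import numpy as np

# Inputs are f32; the result type requests f64 accumulation.
local = [np.full((3, 4), k, dtype=np.float32) for k in (1.0, 2.0)]
# Convert to the result element type first, then reduce (here: max).
result = np.maximum.reduce([x.astype(np.float64) for x in local])
assert result.dtype == np.float64 and result[0, 0] == 2.0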

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultShape

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute  MLIR Type                        Description
mesh       ::mlir::FlatSymbolRefAttr        flat symbol reference attribute
mesh_axes  ::mlir::DenseI16ArrayAttr        i16 dense array attribute
reduction  ::mlir::mesh::ReductionKindAttr  reduction of an iterator/mesh dimension

Enum cases for reduction:

  • sum (Sum)
  • max (Max)
  • min (Min)
  • product (Product)
  • average (Average)
  • bitwise_and (BitwiseAnd)
  • bitwise_or (BitwiseOr)
  • bitwise_xor (BitwiseXor)
  • generic (Generic)

Operands:

Operand  Description
input    ranked tensor of any type values

Results:

Result  Description
result  ranked tensor of any type values

mesh.all_slice (mesh::AllSliceOp) 

All-slice over a device mesh. This is the inverse of all-gather.

Syntax:

operation ::= `mesh.all_slice` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
              attr-dict `:` type($input) `->` type($result)

Slices along the slice_axis tensor axis. This operation can be thought of as the inverse of all-gather. Technically, it is not required that all processes have the same input tensor. Each process slices a piece of its local tensor based on its in-group device index. The operation does not communicate data between devices.

Example:

mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
  : tensor<2x4xi8> -> tensor<2x2xi8>

Input:

+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+

Result:

slice tensor
axis 1
------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
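
Since no communication happens, each device's result is just a local slice. A numpy sketch for the group formed by devices (0, 0) and (0, 1) (illustrative only):

import numpy as np

# Each device already holds `full` (e.g. after an all-gather); all-slice
# keeps only the piece selected by the device's in-group index.
full = np.array([[1, 2, 5, 6], [3, 4, 7, 8]], dtype=np.int8)
pieces = np.split(full, 2, axis=1)             # group size 2, slice_axis = 1
assert pieces[0].tolist() == [[1, 2], [3, 4]]  # kept by in-group device 0
assert pieces[1].tolist() == [[5, 6], [7, 8]]  # kept by in-group device 1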

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute   MLIR Type                  Description
mesh        ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
mesh_axes   ::mlir::DenseI16ArrayAttr  i16 dense array attribute
slice_axis  ::mlir::IntegerAttr        index attribute

Operands:

Operand  Description
input    non-0-ranked tensor of any type values

Results:

Result  Description
result  non-0-ranked tensor of any type values

mesh.all_to_all (mesh::AllToAllOp) 

All-to-all over a device mesh.

Syntax:

operation ::= `mesh.all_to_all` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `split_axis` `=` $split_axis
              `concat_axis` `=` $concat_axis
              attr-dict `:` type($input) `->` type($result)

Performs an all-to-all on tensor pieces split along split_axis. The resulting pieces are concatenated along concat_axis on each device.

Example:

mesh.mesh @mesh0(shape = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
  split_axis = 0 concat_axis = 0
  : tensor<3x2xi8> -> tensor<3x2xi8>

Input:

 device  device  device
 (0)     (1)     (2)
+-------+-------+-------+  | split and concat along
| 11 12 | 21 22 | 31 32 |  | tensor axis 0
| 13 14 | 23 24 | 33 34 |  ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+

Result:

 device  device  device
 (0)     (1)     (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
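
A numpy sketch of the example's exchange pattern (illustrative only; not the op's implementation):

import numpy as np

inputs = [np.array([[11, 12], [13, 14], [15, 16]]),
          np.array([[21, 22], [23, 24], [25, 26]]),
          np.array([[31, 32], [33, 34], [35, 36]])]
n = len(inputs)
# Device s splits its tensor along split_axis and sends piece d to
# device d; each device concatenates what it receives along concat_axis.
split = [np.split(x, n, axis=0) for x in inputs]        # split_axis = 0
outputs = [np.concatenate([split[s][d] for s in range(n)],
                          axis=0)                       # concat_axis = 0
           for d in range(n)]
assert outputs[0].tolist() == [[11, 12], [21, 22], [31, 32]]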

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute    MLIR Type                  Description
mesh         ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
mesh_axes    ::mlir::DenseI16ArrayAttr  i16 dense array attribute
split_axis   ::mlir::IntegerAttr        index attribute
concat_axis  ::mlir::IntegerAttr        index attribute

Operands:

Operand  Description
input    non-0-ranked tensor of any type values

Results:

Result  Description
result  non-0-ranked tensor of any type values

mesh.broadcast (mesh::BroadcastOp) 

Broadcast over a device mesh.

Syntax:

operation ::= `mesh.broadcast` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

Broadcasts the tensor on the root device to all other devices in each respective group. The operation broadcasts along the mesh axes given in mesh_axes. The root attribute specifies the in-group multi-index of the device whose tensor is broadcast to the other devices in its group.

Example:

mesh.mesh @mesh0(shape = 2x2)

%1 = mesh.broadcast %0 on @mesh0
  mesh_axes = [0]
  root = [0]
  : (tensor<2xi8>) -> tensor<2xi8>

Input:

                 +-------+-------+                   | broadcast
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                 +-------+-------+                   ↓
device (1, 0) -> |       |       | <- device (1, 1) 
                 +-------+-------+

Output:

                 +-------+-------+
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
                 +-------+-------+
device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
                 +-------+-------+

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute  MLIR Type                  Description
mesh       ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
mesh_axes  ::mlir::DenseI16ArrayAttr  i16 dense array attribute
root       ::mlir::DenseI64ArrayAttr  i64 dense array attribute

Operands:

Operand       Description
input         ranked tensor of any type values
root_dynamic  variadic of index

Results:

Result  Description
result  ranked tensor of any type values

mesh.gather (mesh::GatherOp) 

Gather over a device mesh.

Syntax:

operation ::= `mesh.gather` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `gather_axis` `=` $gather_axis
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

Gathers on device root along the gather_axis tensor axis. root specifies the coordinates of a device along mesh_axes. It uniquely identifies the root device for each device group. The result tensor on non-root devices is undefined. Using it will result in undefined behavior.

Example:

mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
  gather_axis = 1 root = [1]
  : (tensor<2x2xi8>) -> tensor<2x4xi8>

Input:

                  gather tensor
                  axis 1
                  ------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+

Result:

+-------------+
|  1  2  5  6 | <- device (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- device (1, 1)
| 11 12 15 16 |
+-------------+

Devices (0, 0) and (1, 0) have undefined result.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute    MLIR Type                  Description
mesh         ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
mesh_axes    ::mlir::DenseI16ArrayAttr  i16 dense array attribute
gather_axis  ::mlir::IntegerAttr        index attribute
root         ::mlir::DenseI64ArrayAttr  i64 dense array attribute

Operands:

Operand       Description
input         non-0-ranked tensor of any type values
root_dynamic  variadic of index

Results:

Result  Description
result  non-0-ranked tensor of any type values

mesh.mesh (mesh::MeshOp) 

Description of a device/process mesh.

Syntax:

operation ::= `mesh.mesh` $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
              attr-dict

The mesh.mesh operation is a symbol operation that identifies a specific mesh. The operation has two attributes:

  1. sym_name: This attribute uniquely identifies the name of the mesh. This name serves as a symbolic reference to the mesh throughout the MLIR module, allowing for consistent referencing and easier debugging.

  2. shape: This attribute represents the shape of the device mesh. It uses the same notation as a tensor shape and also allows dynamic dimensions. This flexibility enables dynamic device assignment or configurations where the exact number of devices might not be determined at compile time. For example 2x?x4.

Example:

// A device mesh with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12 
mesh.mesh @mesh0(shape = 4x8x12)

// A device mesh with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
mesh.mesh @mesh1(shape = 4x?)

// A device mesh with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
mesh.mesh @mesh2(shape = ?x4)

// A device mesh with 2 axes, the number of devices along both axes
// is unknown
mesh.mesh @mesh3(shape = ?x?)

Interfaces: Symbol

Attributes: 

Attribute  MLIR Type                  Description
sym_name   ::mlir::StringAttr         string attribute
shape      ::mlir::DenseI64ArrayAttr  i64 dense array attribute

mesh.mesh_shape (mesh::MeshShapeOp) 

Get the shape of the mesh.

Syntax:

operation ::= `mesh.mesh_shape` $mesh (`axes` `=` $axes^)?
              attr-dict `:` type($result)

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute  MLIR Type                  Description
mesh       ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
axes       ::mlir::DenseI16ArrayAttr  i16 dense array attribute

Results:

Result  Description
result  variadic of index

mesh.process_linear_index (mesh::ProcessLinearIndexOp) 

Get the linear index of the current device.

Syntax:

operation ::= `mesh.process_linear_index` `on` $mesh attr-dict `:` type($result)

Example:

%idx = mesh.process_linear_index on @mesh : index

If @mesh has shape (10, 20, 30), a device with multi-index (1, 2, 3) will have linear index 3 + 30*2 + 20*30*1, i.e. the multi-index is linearized in row-major order, with the last mesh axis varying fastest.
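
A sketch of this linearization in plain Python, with a hypothetical helper linear_index (illustrative only):

def linear_index(mesh_shape, multi_index):
    # Row-major linearization: the last mesh axis varies fastest.
    idx = 0
    for extent, coord in zip(mesh_shape, multi_index):
        idx = idx * extent + coord
    return idx

assert linear_index((10, 20, 30), (1, 2, 3)) == 3 + 30 * 2 + 20 * 30 * 1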

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute  MLIR Type                  Description
mesh       ::mlir::FlatSymbolRefAttr  flat symbol reference attribute

Results:

Result  Description
result  index

mesh.process_multi_index (mesh::ProcessMultiIndexOp) 

Get the multi-index of the current device along the specified mesh axes.

Syntax:

operation ::= `mesh.process_multi_index` `on` $mesh (`axes` `=` $axes^)?
              attr-dict `:` type($result)

It is used in the SPMD format of the IR. The axes must be non-negative and less than the total number of mesh axes. If axes is empty, the multi-index along all mesh axes is returned.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute  MLIR Type                  Description
mesh       ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
axes       ::mlir::DenseI16ArrayAttr  i16 dense array attribute

Results:

Result  Description
result  variadic of index

mesh.recv (mesh::RecvOp) 

Receive over a device mesh.

Syntax:

operation ::= `mesh.recv` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
              attr-dict `:` functional-type(operands, results)

Receive from a device within a device group.

Interfaces: OpAsmOpInterface, SymbolUserOpInterface

Attributes: 

Attribute  MLIR Type                  Description
mesh       ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
mesh_axes  ::mlir::DenseI16ArrayAttr  i16 dense array attribute
source     ::mlir::DenseI64ArrayAttr  i64 dense array attribute

Operands:

Operand         Description
input           non-0-ranked tensor of any type values
source_dynamic  variadic of index

Results:

Result  Description
result  ranked tensor of any type values

mesh.reduce (mesh::ReduceOp) 

Reduce over a device mesh.

Syntax:

operation ::= `mesh.reduce` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              (`reduction` `=` $reduction^)?
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

Reduces on the device root within each device group. root specifies the coordinates of a device along mesh_axes. It uniquely identifies the root device within its device group. The accumulation element type is specified by the result type, and it does not need to match the input element type. Input elements are converted to the result element type before the reduction is performed.

Attributes: reduction: Indicates the reduction method.

Example:

%1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
  reduction = <max> root = [2, 3]
  : (tensor<3x4xf32>) -> tensor<3x4xf64>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute  MLIR Type                        Description
mesh       ::mlir::FlatSymbolRefAttr        flat symbol reference attribute
mesh_axes  ::mlir::DenseI16ArrayAttr        i16 dense array attribute
reduction  ::mlir::mesh::ReductionKindAttr  reduction of an iterator/mesh dimension
root       ::mlir::DenseI64ArrayAttr        i64 dense array attribute

Enum cases for reduction:

  • sum (Sum)
  • max (Max)
  • min (Min)
  • product (Product)
  • average (Average)
  • bitwise_and (BitwiseAnd)
  • bitwise_or (BitwiseOr)
  • bitwise_xor (BitwiseXor)
  • generic (Generic)

Operands:

Operand       Description
input         ranked tensor of any type values
root_dynamic  variadic of index

Results:

Result  Description
result  ranked tensor of any type values

mesh.reduce_scatter (mesh::ReduceScatterOp) 

Reduce-scatter over a device mesh.

Syntax:

operation ::= `mesh.reduce_scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              (`reduction` `=` $reduction^)?
              `scatter_axis` `=` $scatter_axis
              attr-dict `:` type($input) `->` type($result)

Reduces the input within each device group and then scatters the result: the reduced tensor is split along scatter_axis and the pieces are distributed across the devices of the group. Example:

mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
  reduction = <sum> scatter_axis = 0
  : tensor<2x2xi8> -> tensor<1x2xi8>

Input:

                          device
                          (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                 |  3  4 |  7  8 |  ↓
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 |
                 | 11 12 | 15 16 |
                 +-------+-------+
                            ↑
                          device
                          (1, 1)

Result:

+-------+
|  6  8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
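
A numpy sketch reproducing the diagram for the group formed by devices (0, 0) and (0, 1), assuming sum reduction (illustrative only):

import numpy as np

# Group of devices (0, 0) and (0, 1): reduce (sum), then scatter along
# tensor axis 0.
shard_00 = np.array([[1, 2], [3, 4]])
shard_01 = np.array([[5, 6], [7, 8]])
reduced = shard_00 + shard_01              # [[6, 8], [10, 12]]
pieces = np.split(reduced, 2, axis=0)      # scatter_axis = 0
assert pieces[0].tolist() == [[6, 8]]      # result on device (0, 0)
assert pieces[1].tolist() == [[10, 12]]    # result on device (0, 1)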

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute     MLIR Type                        Description
mesh          ::mlir::FlatSymbolRefAttr        flat symbol reference attribute
mesh_axes     ::mlir::DenseI16ArrayAttr        i16 dense array attribute
reduction     ::mlir::mesh::ReductionKindAttr  reduction of an iterator/mesh dimension
scatter_axis  ::mlir::IntegerAttr              index attribute

Enum cases for reduction:

  • sum (Sum)
  • max (Max)
  • min (Min)
  • product (Product)
  • average (Average)
  • bitwise_and (BitwiseAnd)
  • bitwise_or (BitwiseOr)
  • bitwise_xor (BitwiseXor)
  • generic (Generic)

Operands:

Operand  Description
input    non-0-ranked tensor of any type values

Results:

Result  Description
result  ranked tensor of any type values

mesh.scatter (mesh::ScatterOp) 

Scatter over a device mesh.

Syntax:

operation ::= `mesh.scatter` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `scatter_axis` `=` $scatter_axis
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

For each device group, split the input tensor on the root device along the scatter_axis axis and scatter the parts across the devices of the group.

Example:

mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
  scatter_axis = 0
  root = [1]
  : (tensor<2x2xi8>) -> tensor<1x2xi8>

Input:

                          device
                          (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |       |       |  | axis 0
                 |       |       |  ↓
                 +-------+-------+
device (1, 0) -> |  1  2 |  5  6 |
                 |  3  4 |  7  8 |
                 +-------+-------+
                            ↑
                          device
                          (1, 1)

Result:

                          device
                          (0, 1)
                             ↓
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 |
                 +-------+-------+ 
device (1, 0) -> |  3  4 |  7  8 |
                 +-------+-------+
                            ↑
                          device
                          (1, 1)

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute     MLIR Type                  Description
mesh          ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
mesh_axes     ::mlir::DenseI16ArrayAttr  i16 dense array attribute
scatter_axis  ::mlir::IntegerAttr        index attribute
root          ::mlir::DenseI64ArrayAttr  i64 dense array attribute

Operands:

Operand       Description
input         non-0-ranked tensor of any type values
root_dynamic  variadic of index

Results:

Result  Description
result  ranked tensor of any type values

mesh.send (mesh::SendOp) 

Send over a device mesh.

Syntax:

operation ::= `mesh.send` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
              attr-dict `:` functional-type(operands, results)

Send from one device to another within a device group.

Interfaces: OpAsmOpInterface, SymbolUserOpInterface

Attributes: 

Attribute    MLIR Type                  Description
mesh         ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
mesh_axes    ::mlir::DenseI16ArrayAttr  i16 dense array attribute
destination  ::mlir::DenseI64ArrayAttr  i64 dense array attribute

Operands:

Operand              Description
input                non-0-ranked tensor of any type values
destination_dynamic  variadic of index

Results:

Result  Description
result  ranked tensor of any type values

mesh.shard (mesh::ShardOp) 

Annotate on how a tensor is sharded across a mesh.

Syntax:

operation ::= `mesh.shard` $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
              type($result)

The mesh.shard operation is designed to specify and guide the sharding behavior of a tensor value across a mesh topology. This operation has one operand and two attributes:

  1. input: This operand represents the tensor value that needs to be annotated for sharding.

  2. shard: This attribute is of type MeshSharding, which is the core data structure used to represent the distribution of a tensor on a mesh.

  3. annotate_for_users: A unit attribute addressing the scenario when a tensor’s sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.

Example:

func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
  ...
}

func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
  ...
}

// The first mesh.shard op applies to %arg0, the second mesh.shard op
// applies to the operand of op0, and the third mesh.shard op applies to
// the operand of op1
func.func @both_result_and_multi_operands_annotated(
    %arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
  %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
  %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
  "op0"(%1) : ...
  "op1"(%2) : ...
  ...
}

The following usages are undefined:

func.func @annotate_on_same_result_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
  %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_result_same_value_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
  %1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_operand_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
  %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
  ...
}

func.func @result_annotated_after_operand(
    %arg0 : tensor<4x8xf32>) -> () {
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
  %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
  ...
}

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute           MLIR Type                       Description
shard               ::mlir::mesh::MeshShardingAttr  attribute that extends tensor type to distributed tensor type (see the MeshShardingAttr section below)
annotate_for_users  ::mlir::UnitAttr                unit attribute

Operands:

Operand  Description
src      ranked tensor of any type values

Results:

Result  Description
result  ranked tensor of any type values

mesh.shift (mesh::ShiftOp) 

Shift over a device mesh.

Syntax:

operation ::= `mesh.shift` $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
              `shift_axis` `=` $shift_axis
              `offset` `=` $offset
              (`rotate` $rotate^)?
              attr-dict `:` type($input) `->` type($result)

Within each device group, shift along the mesh axis shift_axis by the given offset. The result on devices that do not have a corresponding source is undefined. shift_axis must be one of mesh_axes. If the rotate attribute is present, a rotation is performed instead of a shift.

Example:

mesh.mesh @mesh0(shape = 2x4)
%1 = mesh.shift %0 on @mesh0 mesh_axes = [1]
  shift_axis = 1 offset = 2 rotate
  : tensor<2xi8> -> tensor<2xi8>

Input:

mesh axis 1
----------->

+----+----+----+----+
|  1 |  2 |  3 |  4 |
+----+----+----+----+
|  5 |  6 |  7 |  8 |
+----+----+----+----+

Result:

+----+----+----+----+
|  3 |  4 |  1 |  2 |
+----+----+----+----+
|  7 |  8 |  5 |  6 |
+----+----+----+----+
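
A numpy sketch of the rotation for one row of the mesh (illustrative only):

import numpy as np

# Values held along mesh axis 1 before the shift.
row = np.array([1, 2, 3, 4])
# With rotate, offset = 2 moves each value two positions forward along
# the axis, wrapping around: device p receives the value of device
# (p - 2) mod 4.
assert np.roll(row, 2).tolist() == [3, 4, 1, 2]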

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultShape

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute   MLIR Type                  Description
mesh        ::mlir::FlatSymbolRefAttr  flat symbol reference attribute
mesh_axes   ::mlir::DenseI16ArrayAttr  i16 dense array attribute
shift_axis  ::mlir::IntegerAttr        index attribute
offset      ::mlir::IntegerAttr        64-bit signless integer attribute
rotate      ::mlir::UnitAttr           unit attribute

Operands:

Operand  Description
input    non-0-ranked tensor of any type values

Results:

Result  Description
result  ranked tensor of any type values

Attributes 

MeshShardingAttr 

Attribute that extends tensor type to distributed tensor type.

Syntax:

#mesh.shard<
  ::mlir::FlatSymbolRefAttr,   # mesh
  ::llvm::ArrayRef<MeshAxesAttr>,   # split_axes
  ::llvm::ArrayRef<MeshAxis>,   # partial_axes
  ::mlir::mesh::ReductionKind   # partial_type
>

The MeshSharding attribute is used in a mesh.shard operation. It specifies how a tensor is sharded and distributed across the process mesh.

  1. mesh: this attribute is a FlatSymbolRefAttr that refers to the device mesh where the distributed tensor is placed. The symbol must resolve to a mesh.mesh operation.

  2. split_axes: is an array composed of int64_t sub-arrays. The outer array's maximum size is the rank of the related tensor. For the i-th sub-array, if its value is [x, y], it indicates that the tensor's i-th dimension is split along the x and y axes of the device mesh.

  3. partial_axes: if not empty, this signifies that the tensor is a partial one along the specified mesh axes. An all-reduce should be applied to obtain the complete tensor, with the reduction type specified by partial_type.

  4. partial_type: indicates the reduction type of the possible all-reduce op. Note that generic is not an allowed value inside a shard attribute.

Example:

mesh.mesh @mesh0(shape = 2x2x4)

// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
#mesh.shard<@mesh0, [[]]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0
#mesh.shard<@mesh0, [[0]]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
#mesh.shard<@mesh0, [[0], []], partial = sum[1]>

// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
#mesh.shard<@mesh0, [[0]], partial = max[1]>

// Could be used in the attribute of mesh.shard op
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
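
A sketch of the local shard shape implied by split_axes, with a hypothetical helper shard_shape and assuming even divisibility (plain Python, illustrative only):

from math import prod

def shard_shape(tensor_shape, mesh_shape, split_axes):
    # Each tensor dimension is divided by the product of the sizes of
    # the mesh axes it is split over.
    local = list(tensor_shape)
    for dim, axes in enumerate(split_axes):
        local[dim] //= prod(mesh_shape[a] for a in axes)
    return tuple(local)

# tensor<4x8xf32> sharded with #mesh.shard<@mesh0, [[0]]> on the 2x2x4
# mesh @mesh0: the first dimension is split across mesh axis 0 (size 2).
assert shard_shape((4, 8), (2, 2, 4), [[0]]) == (2, 8)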

Parameters: 

Parameter     C++ type                        Description
mesh          ::mlir::FlatSymbolRefAttr       the mesh on which tensors are sharded
split_axes    ::llvm::ArrayRef<MeshAxesAttr>
partial_axes  ::llvm::ArrayRef<MeshAxis>
partial_type  ::mlir::mesh::ReductionKind

ReductionKindAttr 

Reduction of an iterator/mesh dimension.

Syntax:

#mesh.partial<
  ::mlir::mesh::ReductionKind   # value
>

Enum cases:

  • sum (Sum)
  • max (Max)
  • min (Min)
  • product (Product)
  • average (Average)
  • bitwise_and (BitwiseAnd)
  • bitwise_or (BitwiseOr)
  • bitwise_xor (BitwiseXor)
  • generic (Generic)

Parameters: 

Parameter  C++ type                     Description
value      ::mlir::mesh::ReductionKind  an enum of type ReductionKind