MLIR

Multi-Level IR Compiler Framework

'shard' Dialect

The ‘shard’ dialect defines a set of attributes, operations, and interfaces for working with tensor sharding and device communication.

It’s inspired by GSPMD (General and Scalable Parallelization for ML Computation Graphs).

Originally, the dialect was called mesh, but it was renamed to better reflect what it actually does.

Collective Communication Operations 

The ‘shard’ dialect includes several collective operations that help coordinate communication between devices arranged in a grid.

If you’re not already familiar with collective operations, this Wikipedia article is a good starting point.

Unlike traditional collectives that are defined in terms of message-passing between explicit buffers on each process, the collectives in this dialect work at a higher level. They’re defined in terms of how data moves across the dimensions of a tensor, and the participating processes are inferred from how the tensor is sharded rather than specified manually.

Device Groups 

Each collective operation runs within a group of devices. You define groups using the grid and grid_axes attributes, which describe how to slice the full device grid into smaller groups.

Devices that have the same coordinates outside the listed grid_axes belong to the same group.

Example: Say your device grid is shaped 2×3×4×5, and you set grid_axes = [0, 1]. This splits the grid into groups by fixing axes 2 and 3. You’d get groups like:

{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 }

So the groups are identified by the coordinates (k, m), and devices like (1, 0, 2, 3) and (1, 1, 2, 3) are in the same group. But (1, 0, 2, 4) is in a different group.
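The grouping rule can be checked with a small pure-Python model. This is a sketch that assumes a fully static grid shape; `device_groups` is a hypothetical helper, not part of MLIR. It enumerates every device and keys it by its coordinates on the axes not listed in grid_axes.

```python
from itertools import product


def device_groups(grid_shape, grid_axes):
    """Partition every device of a static grid into collective groups.

    Devices that share the same coordinates on all axes NOT listed in
    grid_axes end up in the same group. The result maps those fixed
    coordinates to the member devices.
    """
    other_axes = [a for a in range(len(grid_shape)) if a not in grid_axes]
    groups = {}
    for device in product(*(range(n) for n in grid_shape)):
        key = tuple(device[a] for a in other_axes)
        groups.setdefault(key, []).append(device)
    return groups


groups = device_groups((2, 3, 4, 5), grid_axes=[0, 1])
print(len(groups))            # 20 groups, one per (k, m)
print(len(groups[(2, 3)]))    # 6 devices in each group (2 * 3)
print((1, 0, 2, 3) in groups[(2, 3)], (1, 0, 2, 4) in groups[(2, 3)])  # True False
```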

For some collectives (like all-to-all), the order of devices in the group matters. The device order is based on the order of axes in grid_axes, from outermost to innermost.

Example: If grid_axes = [3, 1], then device (i, 1, k, 0) comes before (i, 0, k, 1) and (i, 2, k, 0).
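One way to read "outermost to innermost" is as a row-major linearization over the grid_axes only. The sketch below assumes that reading and a static grid shaped 2×3×4×5; `in_group_position` is a hypothetical helper. It reproduces the ordering in the example.

```python
def in_group_position(device, grid_shape, grid_axes):
    """Position of `device` within its group.

    grid_axes is read from outermost to innermost, so this is a row-major
    linearization over those axes only; all other axes are ignored.
    """
    position = 0
    for axis in grid_axes:
        position = position * grid_shape[axis] + device[axis]
    return position


shape = (2, 3, 4, 5)
axes = [3, 1]
i, k = 0, 0  # any fixed values; they do not affect the position
print(in_group_position((i, 1, k, 0), shape, axes))  # 1
print(in_group_position((i, 2, k, 0), shape, axes))  # 2
print(in_group_position((i, 0, k, 1), shape, axes))  # 3
```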

In-group Devices 

Some operations (like broadcast, scatter, and send) refer to a specific device within each group. These in-group devices are identified using their coordinates over the axes listed in grid_axes.

Example: In a 3D grid with grid_axes = [0, 2], an in-group device is specified as (i, j). If a group is fixed at coordinate g on axis 1, then the full device index would be (i, g, j).
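Going from an in-group device back to a full grid coordinate just interleaves the two sets of coordinates. The sketch below is a hypothetical helper (`full_device_index`). It assumes the in-group coordinates are given in grid_axes order and the group coordinates in ascending axis order.

```python
def full_device_index(in_group, group_coords, ndim, grid_axes):
    """Rebuild a full grid coordinate from its two parts.

    in_group:     coordinates along grid_axes, listed in grid_axes order.
    group_coords: coordinates along the remaining axes, in ascending axis order.
    """
    by_axis = dict(zip(grid_axes, in_group))
    rest = iter(group_coords)
    return tuple(by_axis[a] if a in by_axis else next(rest) for a in range(ndim))


# 3-D grid, grid_axes = [0, 2]; in-group device (i, j) = (1, 2) in the group g = 0.
print(full_device_index((1, 2), (0,), ndim=3, grid_axes=[0, 2]))  # (1, 0, 2)
```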

Purity and Execution Model 

Collective operations involve all devices in a group (e.g. all-gather, all-to-all) and are considered pure. Operations like send and recv are not collective and are not pure.

The execution model assumes SPMD (Single Program, Multiple Data):

  • Every process runs the same program.
  • At any collective operation, all processes are in sync.

This means compiler optimizations must treat collective ops carefully. For example, if a collective is removed during optimization, it must be removed from every path and every process that would have participated; otherwise, you’ll get undefined behavior at runtime.

Marking these ops as pure also helps with standard compiler passes like dead code elimination and common subexpression elimination. Combined with the SPMD assumption, this ensures that when the program is executed, all devices reach each collective at the same point in the program, which avoids deadlocks.

Operations 


shard.all_gather (shard::AllGatherOp) 

All-gather over a device grid.

Syntax:

operation ::= `shard.all_gather` $input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis
              attr-dict `:` type($input) `->` type($result)

Gathers along the gather_axis tensor axis.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
  : tensor<2x2xi8> -> tensor<2x4xi8>

Input:

                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+

Result:

gather tensor
axis 1
------------>
+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
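The diagram can be reproduced with a pure-Python model of all-gather. This is a sketch only: tensors are 2-D nested lists, the grid shape is static, and `group_members`, `concat`, and `all_gather` are hypothetical helpers rather than MLIR APIs. Each device concatenates the local tensors of its group, in in-group order, along gather_axis.

```python
from itertools import product


def group_members(device, grid_shape, grid_axes):
    """Devices in the same group as `device`, in in-group order."""
    ranges = [range(grid_shape[a]) if a in grid_axes else [device[a]]
              for a in range(len(grid_shape))]
    # Lexicographic order over grid_axes (outermost first) is the in-group order.
    return sorted(product(*ranges), key=lambda d: [d[a] for a in grid_axes])


def concat(tensors, axis):
    """Concatenate 2-D nested lists along tensor axis 0 or 1."""
    if axis == 0:
        return [row for t in tensors for row in t]
    return [[x for t in tensors for x in t[r]] for r in range(len(tensors[0]))]


def all_gather(local, grid_shape, grid_axes, gather_axis):
    """Every device receives its group's local tensors concatenated in order."""
    return {dev: concat([local[m] for m in group_members(dev, grid_shape, grid_axes)],
                        gather_axis)
            for dev in local}


local = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
result = all_gather(local, (2, 2), grid_axes=[1], gather_axis=1)
print(result[(0, 0)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
print(result[(1, 1)])  # [[9, 10, 13, 14], [11, 12, 15, 16]]
```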

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
gather_axis | ::mlir::IntegerAttr | index attribute

Operands:

Operand | Description
input | non-0-ranked.tensor of any type values

Results:

Result | Description
result | non-0-ranked.tensor of any type values

shard.all_reduce (shard::AllReduceOp) 

All-reduce over a device grid.

Syntax:

operation ::= `shard.all_reduce` $input `on` $grid (`grid_axes` `=` $grid_axes^)? (`reduction` `=` $reduction^)?
              attr-dict `:` type($input) `->` type($result)

The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.

The reduction attribute indicates the reduction method.

Example:

%1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max>
  : tensor<3x4xf32> -> tensor<3x4xf64>
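A pure-Python model of all-reduce shows the two rules from the description: every group member gets the same reduced value, and inputs are converted to the result element type first. This is a sketch under stated assumptions: small 1-D lists stand in for tensors, `convert=float` stands in for the f32 to f64 conversion, and `REDUCTIONS` covers only a hypothetical subset (sum, max, min) of the dialect's reduction kinds.

```python
from itertools import product

# Hypothetical subset of reduction kinds, for illustration only.
REDUCTIONS = {"sum": lambda a, b: a + b, "max": max, "min": min}


def group_members(device, grid_shape, grid_axes):
    """All devices that share `device`'s coordinates outside grid_axes.

    Order does not matter here because the reductions are commutative.
    """
    ranges = [range(grid_shape[a]) if a in grid_axes else [device[a]]
              for a in range(len(grid_shape))]
    return list(product(*ranges))


def all_reduce(local, grid_shape, grid_axes, reduction, convert=float):
    """Every device receives the elementwise reduction over its group.

    `convert` models casting each input element to the result element type
    before reducing (for example f32 -> f64 in the MLIR example).
    """
    op = REDUCTIONS[reduction]
    out = {}
    for dev in local:
        members = group_members(dev, grid_shape, grid_axes)
        acc = [convert(x) for x in local[members[0]]]
        for m in members[1:]:
            acc = [op(a, convert(x)) for a, x in zip(acc, local[m])]
        out[dev] = acc
    return out


# 1-D local tensors on a 2x2 grid, reduced over both grid axes.
local = {(0, 0): [1, 8], (0, 1): [4, 2], (1, 0): [3, 5], (1, 1): [7, 6]}
result = all_reduce(local, (2, 2), grid_axes=[1, 0], reduction="max")
print(result[(0, 1)])  # [7.0, 8.0]
```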

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultShape

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
reduction | ::mlir::shard::ReductionKindAttr | Reduction of an iterator/grid dimension.

Operands:

Operand | Description
input | memref of any type values or ranked tensor of any type values

Results:

Result | Description
result | memref of any type values or ranked tensor of any type values

shard.all_slice (shard::AllSliceOp) 

All-slice over a device grid. This is the inverse of all-gather.

Syntax:

operation ::= `shard.all_slice` $input `on` $grid (`grid_axes` `=` $grid_axes^)? `slice_axis` `=` $slice_axis
              attr-dict `:` type($input) `->` type($result)

Slice along the slice_axis tensor axis. This operation can be thought of as the inverse of all-gather. Technically, it is not required that all processes have the same input tensor. Each process will slice a piece of its local tensor based on its in-group device index. The operation does not communicate data between devices.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1
  : tensor<2x4xi8> -> tensor<2x2xi8>

Input:

+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+

Result:

slice tensor
axis 1
------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
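Because all-slice moves no data, a model of it only needs each device's in-group position. The sketch below assumes 2-D nested lists that divide evenly and a static grid. `in_group_position_and_size`, `take_piece`, and `all_slice` are hypothetical helpers.

```python
def in_group_position_and_size(device, grid_shape, grid_axes):
    """Row-major position over grid_axes and the number of devices in the group."""
    position, size = 0, 1
    for axis in grid_axes:
        position = position * grid_shape[axis] + device[axis]
        size *= grid_shape[axis]
    return position, size


def take_piece(tensor, axis, index, count):
    """Return piece `index` of `count` equal pieces of a 2-D nested list."""
    if axis == 0:
        n = len(tensor) // count
        return tensor[index * n:(index + 1) * n]
    n = len(tensor[0]) // count
    return [row[index * n:(index + 1) * n] for row in tensor]


def all_slice(local, grid_shape, grid_axes, slice_axis):
    """Each device keeps the piece of its own tensor selected by its position.

    No data moves between devices.
    """
    out = {}
    for dev, tensor in local.items():
        position, size = in_group_position_and_size(dev, grid_shape, grid_axes)
        out[dev] = take_piece(tensor, slice_axis, position, size)
    return out


top = [[1, 2, 5, 6], [3, 4, 7, 8]]
bottom = [[9, 10, 13, 14], [11, 12, 15, 16]]
local = {(0, 0): top, (0, 1): top, (1, 0): bottom, (1, 1): bottom}
result = all_slice(local, (2, 2), grid_axes=[1], slice_axis=1)
print(result[(0, 1)])  # [[5, 6], [7, 8]]
```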

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
slice_axis | ::mlir::IntegerAttr | index attribute

Operands:

Operand | Description
input | non-0-ranked.tensor of any type values

Results:

Result | Description
result | non-0-ranked.tensor of any type values

shard.all_to_all (shard::AllToAllOp) 

All-to-all over a device grid.

Syntax:

operation ::= `shard.all_to_all` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
              `split_axis` `=` $split_axis
              `concat_axis` `=` $concat_axis
              attr-dict `:` type($input) `->` type($result)

Performs an all-to-all on tensor pieces split along split_axis. The resulting pieces are concatenated along concat_axis on each device.

Example:

shard.grid @grid0(shape = 3)
...
%1 = shard.all_to_all %0 on @grid0 grid_axes = [0]
  split_axis = 0 concat_axis = 0
  : tensor<3x2xi8> -> tensor<3x2xi8>

Input:

 device  device  device
 (0)     (1)     (2)
+-------+-------+-------+  | split and concat along
| 11 12 | 21 22 | 31 32 |  | tensor axis 0
| 13 14 | 23 24 | 33 34 |  ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+

Result:

 device  device  device
 (0)     (1)     (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
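The data movement in the diagram can be modelled in pure Python. This is a sketch that assumes a single group whose local tensors are listed in in-group order and 2-D nested lists that split evenly; `split`, `concat`, and `all_to_all` are hypothetical helpers. Device j receives piece j from every device i and concatenates those pieces in order of i.

```python
def split(tensor, axis, parts):
    """Split a 2-D nested list into `parts` equal pieces along axis 0 or 1."""
    if axis == 0:
        n = len(tensor) // parts
        return [tensor[i * n:(i + 1) * n] for i in range(parts)]
    n = len(tensor[0]) // parts
    return [[row[i * n:(i + 1) * n] for row in tensor] for i in range(parts)]


def concat(tensors, axis):
    """Concatenate 2-D nested lists along tensor axis 0 or 1."""
    if axis == 0:
        return [row for t in tensors for row in t]
    return [[x for t in tensors for x in t[r]] for r in range(len(tensors[0]))]


def all_to_all(group, split_axis, concat_axis):
    """`group` lists one local tensor per device, in in-group order."""
    n = len(group)
    pieces = [split(t, split_axis, n) for t in group]
    # Device j collects piece j from every device i, ordered by i.
    return [concat([pieces[i][j] for i in range(n)], concat_axis) for j in range(n)]


group = [
    [[11, 12], [13, 14], [15, 16]],  # device (0)
    [[21, 22], [23, 24], [25, 26]],  # device (1)
    [[31, 32], [33, 34], [35, 36]],  # device (2)
]
result = all_to_all(group, split_axis=0, concat_axis=0)
print(result[0])  # [[11, 12], [21, 22], [31, 32]]
```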

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
split_axis | ::mlir::IntegerAttr | index attribute
concat_axis | ::mlir::IntegerAttr | index attribute

Operands:

Operand | Description
input | non-0-ranked.tensor of any type values

Results:

Result | Description
result | non-0-ranked.tensor of any type values

shard.broadcast (shard::BroadcastOp) 

Broadcast over a device grid.

Syntax:

operation ::= `shard.broadcast` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

Broadcast the tensor on root to all devices in each respective group. The operation broadcasts along the grid axes listed in grid_axes. root specifies the in-group multi-index of the device whose tensor is broadcast to all other devices in the group.

Example:

shard.grid @grid0(shape = 2x2)

%1 = shard.broadcast %0 on @grid0
  grid_axes = [0]
  root = [0]
  : (tensor<2xi8>) -> tensor<2xi8>

Input:

                 +-------+-------+                   | broadcast
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                 +-------+-------+                   ↓
device (1, 0) -> |       |       | <- device (1, 1) 
                 +-------+-------+

Output:

                 +-------+-------+
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
                 +-------+-------+
device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
                 +-------+-------+
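The root of each group is found by overwriting a device's coordinates on grid_axes with root. A minimal pure-Python model of broadcast follows. `root_device` and `broadcast` are hypothetical helpers, and `None` marks inputs on non-root devices that are never read.

```python
def root_device(device, grid_axes, root):
    """Full coordinate of the root of `device`'s group."""
    coords = list(device)
    for axis, coord in zip(grid_axes, root):
        coords[axis] = coord
    return tuple(coords)


def broadcast(local, grid_axes, root):
    """Every device receives a copy of its group root's tensor."""
    return {dev: list(local[root_device(dev, grid_axes, root)]) for dev in local}


# None marks devices whose input is irrelevant (they are not roots).
local = {(0, 0): [1, 2], (0, 1): [3, 4], (1, 0): None, (1, 1): None}
result = broadcast(local, grid_axes=[0], root=[0])
print(result[(1, 0)], result[(1, 1)])  # [1, 2] [3, 4]
```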

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | ranked tensor of any type values
root_dynamic | variadic of index

Results:

Result | Description
result | ranked tensor of any type values

shard.gather (shard::GatherOp) 

Gather over a device grid.

Syntax:

operation ::= `shard.gather` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
              `gather_axis` `=` $gather_axis
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

Gathers on device root along the gather_axis tensor axis. root specifies the coordinates of a device along grid_axes. It uniquely identifies the root device for each device group. The result tensor on non-root devices is undefined. Using it will result in undefined behavior.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.gather %0 on @grid0 grid_axes = [1]
  gather_axis = 1 root = [1]
  : (tensor<2x2xi8>) -> tensor<2x4xi8>

Input:

                  gather tensor
                  axis 1
                  ------------>
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+

Result:

+-------------+
|  1  2  5  6 | <- device (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- device (1, 1)
| 11 12 15 16 |
+-------------+

Devices (0, 0) and (1, 0) have an undefined result.
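Gather differs from all-gather only in where the result is defined. In the pure-Python model below, a device is a root when its coordinates on grid_axes equal root. The model returns `None` on non-root devices to stand in for the undefined result. As with the other sketches, tensors are 2-D nested lists and all helper names are hypothetical.

```python
from itertools import product


def group_members(device, grid_shape, grid_axes):
    """Devices in the same group as `device`, in in-group order."""
    ranges = [range(grid_shape[a]) if a in grid_axes else [device[a]]
              for a in range(len(grid_shape))]
    return sorted(product(*ranges), key=lambda d: [d[a] for a in grid_axes])


def concat(tensors, axis):
    """Concatenate 2-D nested lists along tensor axis 0 or 1."""
    if axis == 0:
        return [row for t in tensors for row in t]
    return [[x for t in tensors for x in t[r]] for r in range(len(tensors[0]))]


def gather(local, grid_shape, grid_axes, gather_axis, root):
    out = {}
    for dev in local:
        is_root = all(dev[a] == r for a, r in zip(grid_axes, root))
        if is_root:
            members = group_members(dev, grid_shape, grid_axes)
            out[dev] = concat([local[m] for m in members], gather_axis)
        else:
            out[dev] = None  # undefined in the dialect; modelled here as None
    return out


local = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
result = gather(local, (2, 2), grid_axes=[1], gather_axis=1, root=[1])
print(result[(0, 1)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]
```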

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
gather_axis | ::mlir::IntegerAttr | index attribute
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | non-0-ranked.tensor of any type values
root_dynamic | variadic of index

Results:

Result | Description
result | non-0-ranked.tensor of any type values

shard.get_sharding (shard::GetShardingOp) 

Get the sharding of the given tensor.

Syntax:

operation ::= `shard.get_sharding` $source attr-dict `:` type($source) `->` type($result)

This operation returns the sharding of the given tensor as a Sharding.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
source | ranked tensor of any type values

Results:

Result | Description
result | sharding definition

shard.grid (shard::GridOp) 

Description of a device/process grid.

Syntax:

operation ::= `shard.grid` $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
              attr-dict

The shard.grid operation is a symbol operation that identifies a specific grid. The operation has two attributes:

  1. sym_name: This attribute uniquely identifies the name of the grid. This name serves as a symbolic reference to the grid throughout the MLIR module, allowing for consistent referencing and easier debugging.

  2. shape: This attribute represents the shape of the device grid. It uses the same notation as a tensor shape, including dynamic dimensions (for example, 2x?x4). This flexibility allows for dynamic device assignment or for configurations where the exact number of devices is not known at compile time.

Example:

// A device grid with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12 
shard.grid @grid0(shape = 4x8x12)

// A device grid with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
shard.grid @grid1(shape = 4x?)

// A device grid with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
shard.grid @grid2(shape = ?x4)

// A device grid with 2 axes, the number of devices along both axes
// is unknown
shard.grid @grid3(shape = ?x?)

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), Symbol

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
sym_name | ::mlir::StringAttr | string attribute
shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

shard.grid_shape (shard::GridShapeOp) 

Get the shape of the grid.

Syntax:

operation ::= `shard.grid_shape` $grid (`axes` `=` $axes^)?
              attr-dict `:` type($result)

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute

Results:

Result | Description
result | variadic of index

shard.neighbors_linear_indices (shard::NeighborsLinearIndicesOp) 

For a given grid index, get the linear indices of the direct neighbor processes along the given split.

Syntax:

operation ::= `shard.neighbors_linear_indices` `on` $grid `[` $device `]`
              `split_axes` `=` $split_axes
              attr-dict `:` type(results)

Example:

shard.grid @grid0(shape = 10x20x30)
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%down, %up = shard.neighbors_linear_indices on @grid0[%c1, %c2, %c3] split_axes = [1] : index, index

The above returns two indices, 633 and 693, which correspond to the previous process (1, 1, 3) and the next process (1, 3, 3) along split axis 1.

A negative value is returned if there is no neighbor in the respective direction along the given split_axes.
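The numbers in the example follow from row-major linearization over the grid shape. The pure-Python sketch below handles a single split axis only. It uses -1 as the negative "no neighbor" value; the exact negative value the dialect returns is an assumption here, and the helper names are hypothetical.

```python
def linear_index(multi_index, grid_shape):
    """Row-major linear index: the last grid axis varies fastest."""
    index = 0
    for coord, size in zip(multi_index, grid_shape):
        index = index * size + coord
    return index


def neighbors_linear_indices(device, grid_shape, split_axis):
    """(down, up) linear indices along one split axis; -1 where none exists."""
    down = up = -1
    c = device[split_axis]
    if c > 0:
        prev = list(device)
        prev[split_axis] = c - 1
        down = linear_index(prev, grid_shape)
    if c + 1 < grid_shape[split_axis]:
        nxt = list(device)
        nxt[split_axis] = c + 1
        up = linear_index(nxt, grid_shape)
    return down, up


print(neighbors_linear_indices((1, 2, 3), (10, 20, 30), split_axis=1))  # (633, 693)
print(neighbors_linear_indices((1, 0, 3), (10, 20, 30), split_axis=1))  # (-1, 633)
```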

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
split_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute

Operands:

Operand | Description
device | variadic of index

Results:

Result | Description
neighbor_down | index
neighbor_up | index

shard.process_linear_index (shard::ProcessLinearIndexOp) 

Get the linear index of the current device.

Syntax:

operation ::= `shard.process_linear_index` `on` $grid attr-dict `:` type($result)

Example:

%idx = shard.process_linear_index on @grid : index

If @grid has shape (10, 20, 30), a device with multi-index (1, 2, 3) will have linear index 3 + 30*2 + 20*30*1 = 663.
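The formula above is a row-major linearization, and it can be inverted with repeated `divmod`. This is a pure-Python sketch; `linear_index` and `multi_index` are hypothetical helpers. The inverse assumes the same row-major convention that the formula uses.

```python
def linear_index(multi_index, grid_shape):
    """Row-major linear index: the last grid axis varies fastest."""
    index = 0
    for coord, size in zip(multi_index, grid_shape):
        index = index * size + coord
    return index


def multi_index(linear, grid_shape):
    """Inverse of linear_index under the same row-major convention."""
    coords = []
    for size in reversed(grid_shape):
        linear, coord = divmod(linear, size)
        coords.append(coord)
    return tuple(reversed(coords))


print(linear_index((1, 2, 3), (10, 20, 30)))  # 663
print(multi_index(663, (10, 20, 30)))         # (1, 2, 3)
```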

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute

Results:

Result | Description
result | index

shard.process_multi_index (shard::ProcessMultiIndexOp) 

Get the multi index of current device along specified grid axes.

Syntax:

operation ::= `shard.process_multi_index` `on` $grid (`axes` `=` $axes^)?
              attr-dict `:` type($result)

It is used in the SPMD format of IR. The axes must be non-negative and less than the total number of grid axes. If axes is empty, the index along all axes is returned.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute

Results:

Result | Description
result | variadic of index

shard.recv (shard::RecvOp) 

Receive over a device grid.

Syntax:

operation ::= `shard.recv` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
              (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
              attr-dict `:` functional-type(operands, results)

Receive from a device within a device group.

Interfaces: OpAsmOpInterface, SymbolUserOpInterface

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
source | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | non-0-ranked.tensor of any type values
source_dynamic | variadic of index

Results:

Result | Description
result | ranked tensor of any type values

shard.reduce (shard::ReduceOp) 

Reduce over a device grid.

Syntax:

operation ::= `shard.reduce` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
              (`reduction` `=` $reduction^)?
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

Reduces on device root within each device group. root specifies the coordinates of a device along grid_axes. It uniquely identifies the root device within its device group. The accumulation element type is specified by the result type and it does not need to match the input element type. The input element is converted to the result element type before performing the reduction.

The reduction attribute indicates the reduction method.

Example:

%1 = shard.reduce %0 on @grid0 grid_axes = [1, 0]
  reduction = <max> root = [2, 3]
  : (tensor<3x4xf32>) -> tensor<3x4xf64>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
reduction | ::mlir::shard::ReductionKindAttr | Reduction of an iterator/grid dimension.
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | ranked tensor of any type values
root_dynamic | variadic of index

Results:

Result | Description
result | ranked tensor of any type values

shard.reduce_scatter (shard::ReduceScatterOp) 

Reduce-scatter over a device grid.

Syntax:

operation ::= `shard.reduce_scatter` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
              (`reduction` `=` $reduction^)?
              `scatter_axis` `=` $scatter_axis
              attr-dict `:` type($input) `->` type($result)

After the reduction, the result is scattered within each device group: the tensor is split along scatter_axis and the pieces are distributed across the device group.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
  reduction = <sum> scatter_axis = 0
  : tensor<2x2xf32> -> tensor<1x2xf64>

Input:

                          device
                          (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                 |  3  4 |  7  8 |  ↓
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 |
                 | 11 12 | 15 16 |
                 +-------+-------+
                            ↑
                          device
                          (1, 1)

Result:

+-------+
|  6  8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
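The result values in the diagram correspond to an elementwise sum over each group. Each device then keeps the piece of the sum at its in-group position along scatter_axis. The pure-Python sketch below reproduces them. It is limited to a sum reduction, omits the element-type conversion, and uses hypothetical helper names.

```python
from itertools import product


def group_members(device, grid_shape, grid_axes):
    """Devices in the same group as `device`, in in-group order."""
    ranges = [range(grid_shape[a]) if a in grid_axes else [device[a]]
              for a in range(len(grid_shape))]
    return sorted(product(*ranges), key=lambda d: [d[a] for a in grid_axes])


def split(tensor, axis, parts):
    """Split a 2-D nested list into `parts` equal pieces along axis 0 or 1."""
    if axis == 0:
        n = len(tensor) // parts
        return [tensor[i * n:(i + 1) * n] for i in range(parts)]
    n = len(tensor[0]) // parts
    return [[row[i * n:(i + 1) * n] for row in tensor] for i in range(parts)]


def reduce_scatter(local, grid_shape, grid_axes, scatter_axis):
    """Sum-reduce each group's 2-D tensors, then give each device one piece."""
    out = {}
    for dev in local:
        members = group_members(dev, grid_shape, grid_axes)
        total = local[members[0]]
        for m in members[1:]:
            total = [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(total, local[m])]
        out[dev] = split(total, scatter_axis, len(members))[members.index(dev)]
    return out


local = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
result = reduce_scatter(local, (2, 2), grid_axes=[1], scatter_axis=0)
print(result[(0, 0)], result[(0, 1)])  # [[6, 8]] [[10, 12]]
```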

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultRank

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
reduction | ::mlir::shard::ReductionKindAttr | Reduction of an iterator/grid dimension.
scatter_axis | ::mlir::IntegerAttr | index attribute

Operands:

Operand | Description
input | non-0-ranked.tensor of any type values

Results:

Result | Description
result | ranked tensor of any type values

shard.scatter (shard::ScatterOp) 

Scatter over a device grid.

Syntax:

operation ::= `shard.scatter` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
              `scatter_axis` `=` $scatter_axis
              `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
              attr-dict `:` functional-type(operands, results)

For each device group, split the input tensor on the root device along scatter_axis and scatter the parts across the devices of the group.

Example:

shard.grid @grid0(shape = 2x2)
%1 = shard.scatter %0 on @grid0 grid_axes = [0]
  scatter_axis = 0
  root = [1]
  : (tensor<2x2xi8>) -> tensor<1x2xi8>

Input:

                          device
                          (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |       |       |  | axis 0
                 |       |       |  ↓
                 +-------+-------+
device (1, 0) -> |  1  2 |  5  6 |
                 |  3  4 |  7  8 |
                 +-------+-------+
                            ↑
                          device
                          (1, 1)

Result:

                          device
                          (0, 1)
                             ↓
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 |
                 +-------+-------+ 
device (1, 0) -> |  3  4 |  7  8 |
                 +-------+-------+
                            ↑
                          device
                          (1, 1)
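In a pure-Python model of scatter, each device locates its group's root and takes the piece of the root's tensor at its own in-group position. The sketch assumes 2-D nested lists that split evenly, and its helper names are hypothetical. `None` marks inputs on non-root devices, which are never read.

```python
def split(tensor, axis, parts):
    """Split a 2-D nested list into `parts` equal pieces along axis 0 or 1."""
    if axis == 0:
        n = len(tensor) // parts
        return [tensor[i * n:(i + 1) * n] for i in range(parts)]
    n = len(tensor[0]) // parts
    return [[row[i * n:(i + 1) * n] for row in tensor] for i in range(parts)]


def scatter(local, grid_shape, grid_axes, scatter_axis, root):
    out = {}
    for dev in local:
        # Root of this device's group: same coordinates, except root on grid_axes.
        src = list(dev)
        for axis, coord in zip(grid_axes, root):
            src[axis] = coord
        # Row-major in-group position of this device and the group size.
        position, size = 0, 1
        for axis in grid_axes:
            position = position * grid_shape[axis] + dev[axis]
            size *= grid_shape[axis]
        out[dev] = split(local[tuple(src)], scatter_axis, size)[position]
    return out


local = {(0, 0): None, (0, 1): None,
         (1, 0): [[1, 2], [3, 4]], (1, 1): [[5, 6], [7, 8]]}
result = scatter(local, (2, 2), grid_axes=[0], scatter_axis=0, root=[1])
print(result[(0, 0)], result[(1, 0)])  # [[1, 2]] [[3, 4]]
```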

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
scatter_axis | ::mlir::IntegerAttr | index attribute
root | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | non-0-ranked.tensor of any type values
root_dynamic | variadic of index

Results:

Result | Description
result | ranked tensor of any type values

shard.send (shard::SendOp) 

Send over a device grid.

Syntax:

operation ::= `shard.send` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
              `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
              attr-dict `:` functional-type(operands, results)

Send from one device to another within a device group.

Interfaces: OpAsmOpInterface, SymbolUserOpInterface

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
destination | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | non-0-ranked.tensor of any type values
destination_dynamic | variadic of index

Results:

Result | Description
result | ranked tensor of any type values

shard.shard (shard::ShardOp) 

Annotate how a tensor is sharded across a grid.

Syntax:

operation ::= `shard.shard` $src `to` $sharding
              (`annotate_for_users` $annotate_for_users^)?
              attr-dict `:` type($result)

The shard.shard operation is designed to specify and guide the sharding behavior of a tensor value across a grid topology. This operation has two operands and one optional attribute:

  1. src: This operand represents the tensor value that needs to be annotated for sharding.

  2. sharding: This operand is of type ShardingType, which is the core data structure to represent the distribution of a tensor on a grid. It is typically defined by a shard.sharding operation.

  3. annotate_for_users: A unit attribute addressing the scenario when a tensor’s sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.

Example:

func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
  ...
}

func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
  ...
}

func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
  %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
  %1 = shard.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
  ...
}

// The first shard.shard op applies to %arg0, the second shard.shard op
// applies to the operand of op0, and the third shard.shard op applies to the
// operand of op1
func.func @both_result_and_multi_operands_annotated(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
  %sharding1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
  %1 = shard.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %sharding2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
  %2 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
  "op0"(%1) : ...
  "op1"(%2) : ...
  ...
}

The following usages are undefined:

func.func @annotate_on_same_result_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
  %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_result_same_value_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
  %1 = shard.shard %arg0 to %sharding2 : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_operand_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %1 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
  ...
}

func.func @result_annotated_after_operand(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
  ...
}

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
annotate_for_users | ::mlir::UnitAttr | unit attribute

Operands:

Operand | Description
src | ranked tensor of any type values
sharding | sharding definition

Results:

Result | Description
result | ranked tensor of any type values

shard.shard_shape (shard::ShardShapeOp) 

Get the shard shape for a given process/device.

Syntax:

operation ::= `shard.shard_shape` `dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
              `sharding` `=` $sharding
              `device` `=` custom<DynamicIndexList>($device_dynamic, $device)
              attr-dict `:` type(results)

The device/process id is a multi-index of the device/process in the grid. This operation might be used during partitioning when the shard shape depends on (non-constant) values used in shard.sharding.

Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
dims | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
device | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands: 

Operand | Description
dims_dynamic | variadic of index
sharding | sharding definition
device_dynamic | variadic of index

Results: 

Result | Description
result | variadic of index

shard.sharding (shard::ShardingOp) 

Define a sharding of a tensor.

Syntax:

operation ::= `shard.sharding` $grid
              `split_axes` `=` $split_axes
              (`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
              (`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
              attr-dict `:` type($result)

The sharding specifies how a tensor is sharded and distributed across the process grid. It is typically used in a shard.shard operation. The operation has the following attributes and operands:

  1. grid: this attribute is a FlatSymbolRefAttr that refers to the device grid where the distributed tensor is placed. The symbol must resolve to a shard.grid operation.

  2. split_axes: an array composed of int64_t sub-arrays. The outer array's maximum size is the rank of the related tensor. If the i-th sub-array is [x, y], the tensor's i-th dimension is split along axes x and y of the device grid.

  3. [Optional] Sizes of halos to be added for each sharded tensor dimension. halo_sizes is provided as a flattened 1d array of i64s, with 2 values for each sharded dimension. halo_sizes = [1, 2] means that the first sharded dimension gets an additional halo of size 1 at its start and a halo of size 2 at its end. halo_sizes = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions: the first sharded dimension gets [1, 2] halos and the second gets [2, 3] halos. ? indicates dynamic halo sizes.
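The halo arithmetic can be sketched in Python as follows (an illustration, not the dialect's implementation). The hypothetical add_halos function takes the shape of a shard after splitting and grows every sharded dimension, meaning every dimension whose split_axes entry is non-empty, by its start and end halo. For simplicity it assumes exactly one pair of static halo sizes per sharded dimension.

```python
from typing import List, Sequence


def add_halos(local_shape: Sequence[int],
              split_axes: Sequence[Sequence[int]],
              halo_sizes: Sequence[int]) -> List[int]:
    """Grow each sharded dimension of a shard by its start and end halos.

    halo_sizes is flattened as [start_0, end_0, start_1, end_1, ...]: one
    pair per sharded dimension, i.e. per dimension with non-empty split axes,
    in tensor-dimension order.
    """
    sharded_dims = [dim for dim, axes in enumerate(split_axes) if axes]
    # Sketch assumption: exactly one (start, end) pair per sharded dimension.
    if len(halo_sizes) != 2 * len(sharded_dims):
        raise ValueError("expected two halo sizes per sharded dimension")
    shape = list(local_shape)
    for pair, dim in enumerate(sharded_dims):
        shape[dim] += halo_sizes[2 * pair] + halo_sizes[2 * pair + 1]
    return shape
```

For example, a 4x8 tensor split along axis 0 of a 2x2x4 grid gives each device a 2x8 block, so halo_sizes = [1, 2] yields a local 5x8 tensor.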

  4. [Optional] Offsets for each shard and sharded tensor dimension. sharded_dims_offsets is provided as a flattened 1d array of i64s. For each sharded tensor dimension, it lists the offsets (starting indices) of all shards in that dimension plus one additional value for the end of the last shard. For a 1d sharding this means that position i holds the exclusive prefix sum of the shard sizes for shard i, and, since only contiguous sharding is supported, its inclusive prefix sum is at position i+1.

Assuming a 3d tensor of shape 32x32x32 with the first 2 dimensions being sharded, sharded_dims_offsets = [0, 24, 32, 0, 20, 32] means that the first device of the device grid gets a shard of shape 24x20x32, while the device holding the second shard along both sharded dimensions gets a shard of shape 8x12x32. ? indicates dynamic shard dimensions.
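The following Python sketch decodes a flattened sharded_dims_offsets list and enumerates the resulting shard shapes for the 32x32x32 example. The helper names are hypothetical, and the result is keyed by per-dimension shard indices rather than device coordinates, since the mapping from shards to devices depends on split_axes.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple


def shard_sizes_per_dim(offsets: Sequence[int],
                        shards_per_dim: Sequence[int]) -> List[List[int]]:
    """Split flattened sharded_dims_offsets into per-dimension shard sizes.

    A sharded dimension with n shards contributes n + 1 entries: the start
    offset of every shard followed by the end of the last shard.
    """
    sizes: List[List[int]] = []
    pos = 0
    for n in shards_per_dim:
        bounds = list(offsets[pos:pos + n + 1])
        if len(bounds) != n + 1:
            raise ValueError("too few offsets")
        sizes.append([end - start for start, end in zip(bounds, bounds[1:])])
        pos += n + 1
    if pos != len(offsets):
        raise ValueError("too many offsets")
    return sizes


def all_shard_shapes(global_shape: Sequence[int], sharded_dims: Sequence[int],
                     offsets: Sequence[int], shards_per_dim: Sequence[int]
                     ) -> Dict[Tuple[int, ...], Tuple[int, ...]]:
    """Map each tuple of per-dimension shard indices to that shard's shape."""
    sizes = shard_sizes_per_dim(offsets, shards_per_dim)
    shapes: Dict[Tuple[int, ...], Tuple[int, ...]] = {}
    for idx in product(*(range(n) for n in shards_per_dim)):
        shape = list(global_shape)
        for k, dim in enumerate(sharded_dims):
            shape[dim] = sizes[k][idx[k]]
        shapes[idx] = tuple(shape)
    return shapes
```

For the example offsets, shard (0, 0) is 24x20x32, shard (0, 1) is 24x12x32, shard (1, 0) is 8x20x32, and shard (1, 1) is 8x12x32.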

halo_sizes and sharded_dims_offsets are mutually exclusive.

Examples:

shard.grid @grid0(shape = 2x2x4)
shard.grid @grid1d_4(shape = 4)

// The tensor is fully replicated on @grid0.
// Currently, there must be at least one sub-array present in split_axes, even
// if it's empty. Otherwise, a parsing error will occur.
%sharding0 = shard.sharding @grid0 split_axes = [[]] : !shard.sharding

// The tensor is sharded on its first dimension along axis 0 of @grid0.
%sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding

// Could be used for a shard.shard op
%sharded0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>

// The tensor is sharded on its first dimension along axis 0 of @grid0,
// and it has halo sizes of 1 and 2 on the sharded dimension.
%halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2] : !shard.sharding
%sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>

// The tensor is sharded on its second dimension along axis 0 of @grid1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
%sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14] : !shard.sharding
%sharded2 = shard.shard %arg1 to %sharding4 : tensor<4x14xf32>

Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
split_axes | ::mlir::shard::GridAxesArrayAttr |
static_sharded_dims_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
static_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands: 

Operand | Description
dynamic_sharded_dims_offsets | variadic of 64-bit signless integer
dynamic_halo_sizes | variadic of 64-bit signless integer

Results: 

Result | Description
result | sharding definition

shard.shift (shard::ShiftOp) 

Shift over a device grid.

Syntax:

operation ::= `shard.shift` $input `on` $grid (`grid_axes` `=` $grid_axes^)?
              `shift_axis` `=` $shift_axis
              `offset` `=` $offset
              (`rotate` $rotate^)?
              attr-dict `:` type($input) `->` type($result)

Within each device group, shift along grid axis shift_axis by offset positions. The result on devices that do not have a corresponding source is undefined. shift_axis must be one of grid_axes. If the rotate attribute is present, a rotation (cyclic shift) is performed instead of a shift.

Example:

shard.grid @grid0(shape = 2x4)
%1 = shard.shift %0 on @grid0 grid_axes = [1]
  shift_axis = 1 offset = 2 rotate
  : tensor<2xi8> -> tensor<2xi8>

Input:

grid axis 1
----------->

+----+----+----+----+
|  1 |  2 |  3 |  4 |
+----+----+----+----+
|  5 |  6 |  7 |  8 |
+----+----+----+----+

Result:

+----+----+----+----+
|  3 |  4 |  1 |  2 |
+----+----+----+----+
|  7 |  8 |  5 |  6 |
+----+----+----+----+
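The example can be reproduced with a small Python simulation. With grid_axes = [1] on the 2x4 grid, each row is one device group. The sketch assumes that the device at position p along the shift axis receives the value from position p - offset; with offset = 2 and rotate on a group of 4 devices the direction does not change the result, but it does for other offsets. Positions without a source are represented as None, standing in for the undefined result. The function names are hypothetical.

```python
from typing import List, Optional, Sequence


def shift_group(values: Sequence[int], offset: int,
                rotate: bool) -> List[Optional[int]]:
    """Shift one device group's values along the shift axis.

    Assumed direction: position p receives the value held at p - offset.
    Without rotation, positions with no source get None (undefined).
    """
    n = len(values)
    result: List[Optional[int]] = []
    for p in range(n):
        src = p - offset
        if rotate:
            result.append(values[src % n])  # wrap around the group
        elif 0 <= src < n:
            result.append(values[src])
        else:
            result.append(None)
    return result


def shift_along_axis1(grid: Sequence[Sequence[int]], offset: int,
                      rotate: bool) -> List[List[Optional[int]]]:
    """grid[i][j] is the value on device (i, j); each row is one device group."""
    return [shift_group(row, offset, rotate) for row in grid]
```

Calling shift_along_axis1([[1, 2, 3, 4], [5, 6, 7, 8]], 2, True) returns [[3, 4, 1, 2], [7, 8, 5, 6]], matching the result grid.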

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultShape

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
grid_axes | ::mlir::DenseI16ArrayAttr | i16 dense array attribute
shift_axis | ::mlir::IntegerAttr | index attribute
offset | ::mlir::IntegerAttr | 64-bit signless integer attribute
rotate | ::mlir::UnitAttr | unit attribute

Operands: 

Operand | Description
input | non-0-ranked.tensor of any type values

Results: 

Result | Description
result | ranked tensor of any type values

shard.update_halo (shard::UpdateHaloOp) 

Update halo data.

Syntax:

operation ::= `shard.update_halo` $destination
              `on` $grid
              `split_axes` `=` $split_axes
              (`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
              attr-dict `:` type($result)

This operation updates the halo regions of shards, e.g. when their sharding specifies halos and the actual tensor/memref data might have changed on the remote devices. Changes might be caused by mutating operations and/or by new halo regions that are larger than the existing ones.

The destination operand is expected to be initialized with the local data (not the halos).

Assumes all devices hold tensors with same-sized halo data, as specified by halo_sizes/static_halo_sizes.

split_axes specifies, for each tensor axis, the grid axes along which its halo data is updated.
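As an illustration of the data movement (not of how the operation is lowered), this Python sketch refreshes the halos of 1-D shards laid out along a single grid axis. Each destination shard is [start halo | local data | end halo] with the same halo sizes on every device; the start halo is refilled from the tail of the previous shard's local data and the end halo from the head of the next shard's local data. Leaving the halos at the two ends of the axis unchanged, and requiring halos no larger than the local data, are assumptions of this sketch; update_halos_1d is a hypothetical name.

```python
from typing import List, Sequence


def update_halos_1d(shards: Sequence[List[int]], halo_start: int,
                    halo_end: int) -> List[List[int]]:
    """Return copies of `shards` with their halo regions refreshed.

    Each shard is laid out as [start halo | local data | end halo]. Halos at
    the two ends of the grid axis have no neighbor and are left unchanged.
    """
    def local(shard: List[int]) -> List[int]:
        data = shard[halo_start:len(shard) - halo_end]
        if len(data) < max(halo_start, halo_end):
            raise ValueError("halo larger than local data")
        return data

    result = [list(s) for s in shards]
    for i, shard in enumerate(result):
        if i > 0 and halo_start > 0:
            # Start halo <- last halo_start local elements of the left neighbor.
            shard[:halo_start] = local(shards[i - 1])[-halo_start:]
        if i + 1 < len(shards) and halo_end > 0:
            # End halo <- first halo_end local elements of the right neighbor.
            shard[len(shard) - halo_end:] = local(shards[i + 1])[:halo_end]
    return result
```

For two devices holding local data [0, 1, 2, 3] and [4, 5, 6, 7] with halo sizes 1 and 2 (stale halos shown as -1), the result is [-1, 0, 1, 2, 3, 4, 5] and [3, 4, 5, 6, 7, -1, -1].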

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, DestinationStyleOpInterface, NoMemoryEffect (MemoryEffectOpInterface), SymbolUserOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
grid | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute
split_axes | ::mlir::shard::GridAxesArrayAttr |
static_halo_sizes | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands: 

Operand | Description
destination | non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values
halo_sizes | variadic of 64-bit signless integer

Results: 

Result | Description
result | non-0-ranked.memref of any type values or non-0-ranked.tensor of any type values

Attributes 

GridAxesArrayAttr 

Syntax:

#shard.axisarray<
  ::llvm::ArrayRef<GridAxesAttr>   # axes
>

Parameters: 

Parameter | C++ type | Description
axes | ::llvm::ArrayRef<GridAxesAttr> |

ReductionKindAttr 

Reduction of an iterator/grid dimension.

Syntax:

#shard.partial<
  ::mlir::shard::ReductionKind   # value
>

Parameters: 

Parameter | C++ type | Description
value | ::mlir::shard::ReductionKind | an enum of type ReductionKind